         Optional[DictConfig | AlgoConfig],  # config
         torch.Tensor | None,  # rollout_log_probs
     ],
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+    tuple[torch.Tensor, dict[str, Any]],
 ]

 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
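With this hunk the registry-level contract becomes "aggregated loss plus a dict of scalar metrics". A minimal sketch of how the full alias would now read is below; only the last two parameter annotations and the return type come from the hunk above, the five leading parameters are inferred from the per-loss signatures later in this diff, and AlgoConfig is assumed to come from the module's existing imports.

    from typing import Any, Callable, Optional

    import torch
    from omegaconf import DictConfig  # AlgoConfig: the module's existing import, path not shown here

    PolicyLossFn = Callable[
        [
            torch.Tensor,  # old_log_prob (inferred)
            torch.Tensor,  # log_prob (inferred)
            torch.Tensor,  # advantages (inferred)
            torch.Tensor,  # response_mask (inferred)
            str,  # loss_agg_mode (inferred)
            Optional[DictConfig | AlgoConfig],  # config
            torch.Tensor | None,  # rollout_log_probs
        ],
        tuple[torch.Tensor, dict[str, Any]],  # (aggregated loss, scalar metrics)
    ]

Returning a dict also lets each loss report only what it actually computes; for example, the gpg loss below returns an empty dict instead of padding with zero tensors.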
@@ -893,7 +893,7 @@ def compute_policy_loss_vanilla(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for PPO.

@@ -968,7 +968,12 @@ def compute_policy_loss_vanilla(

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("gspo")
@@ -980,7 +985,7 @@ def compute_policy_loss_gspo(
     loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for GSPO.

@@ -1037,8 +1042,12 @@ def compute_policy_loss_gspo(
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
-
-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("gpg")
@@ -1050,7 +1059,7 @@ def compute_policy_loss_gpg(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """Adapted from
     https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
     Args:
@@ -1071,7 +1080,7 @@ def compute_policy_loss_gpg(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+    return pg_loss, {}


 @register_policy_loss("clip_cov")
@@ -1083,7 +1092,7 @@ def compute_policy_loss_clip_cov(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.

@@ -1170,8 +1179,11 @@ def compute_policy_loss_clip_cov(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-
-    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("kl_cov")
@@ -1183,7 +1195,7 @@ def compute_policy_loss_kl_cov(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.

@@ -1246,8 +1258,10 @@ def compute_policy_loss_kl_cov(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-
-    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
+    pg_metrics = {
+        "actor/ppo_kl": ppo_kl_abs.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("geo_mean")
@@ -1259,7 +1273,7 @@ def compute_policy_loss_geo_mean(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for GMPO.

@@ -1328,8 +1342,12 @@ def compute_policy_loss_geo_mean(
     clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)
     pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)
     pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)
-
-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
@@ -1672,12 +1690,14 @@ def compute_policy_loss_with_rollout_correction(
     negative_approx_kl = log_prob - rollout_log_prob
     kl_divergence = verl_F.masked_mean(-negative_approx_kl, effective_mask)

-    # No clipping in pure rollout correction mode
-    clip_fraction = torch.tensor(0.0)
+    pg_metrics = rollout_metrics
+    pg_metrics.update(
+        {
+            "actor/ppo_kl": kl_divergence.detach().item(),
+        }
+    )

-    # Return tuple matching compute_policy_loss signature: (loss, clip_fraction, kl, clip_fraction_lower)
-    # Note: Algorithm metrics (rollout_metrics) should be handled separately by caller
-    return pg_loss, clip_fraction, kl_divergence, clip_fraction
+    return pg_loss, pg_metrics


 @register_policy_loss("rollout_correction")
@@ -1689,7 +1709,7 @@ def compute_policy_loss_rollout_correction_wrapper(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """Wrapper for compute_policy_loss_with_rollout_correction to match PolicyLossFn interface.

     This function is used when algorithm.rollout_correction.use_pure_rollout_correction=True.
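Taken together, every registered policy loss now returns (loss, metrics) rather than a fixed 4-tuple, so call sites unpack two values and merge the dict into their logging metrics. A hypothetical caller-side sketch follows; the keyword argument names, the "vanilla" registry key, and the surrounding metrics accumulator are illustrative assumptions, not part of this diff.

    # Hypothetical call site: the tensors and `config` are assumed to exist in
    # the surrounding training step; only the two-value unpacking and the
    # metric keys (e.g. "actor/pg_clipfrac") follow from this change.
    loss_fn = POLICY_LOSS_REGISTRY["vanilla"]
    pg_loss, pg_metrics = loss_fn(
        old_log_prob=old_log_prob,
        log_prob=log_prob,
        advantages=advantages,
        response_mask=response_mask,
        loss_agg_mode="token-mean",
        config=config,
        rollout_is_weights=None,
    )
    metrics.update(pg_metrics)  # {"actor/pg_clipfrac": ..., "actor/ppo_kl": ..., "actor/pg_clipfrac_lower": ...}

Because the keys are already namespaced under actor/, the dict can be merged into the trainer's logging metrics without renaming.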