Commit a7f58cd

[algo] feat: return loss and metrics from policy_loss_fn (volcengine#4062)
### What does this PR do?

This PR refactors `policy_loss_fn` so that every registered policy-loss function returns `loss` and `metrics`, allowing more flexible loss definitions that can report arbitrary metrics.

### Test

See the CI tests.

### API and Usage Example

See [core_algos.py](https://github.com/volcengine/verl/blob/main/trainer/ppo/core_algos.py) for examples.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ec01579 commit a7f58cd
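
To make the new contract concrete before the per-file hunks: after this commit, a policy-loss function returns a `(loss, metrics)` pair instead of the old `(pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower)` 4-tuple, and callers simply merge the dict into their metrics. The snippet below is a self-contained toy mirror of that registry pattern, not verl's actual code; the names `register_policy_loss`, `get_policy_loss_fn`, and the `actor/...` metric keys come from the diff, while the tensors and the `toy_pg` loss are invented for illustration.

```python
from typing import Any, Callable

import torch

# (loss, metrics) is the new return contract for every registered policy loss.
PolicyLossFn = Callable[..., tuple[torch.Tensor, dict[str, Any]]]
POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}


def register_policy_loss(name: str):
    def wrap(fn: PolicyLossFn) -> PolicyLossFn:
        POLICY_LOSS_REGISTRY[name] = fn
        return fn

    return wrap


def get_policy_loss_fn(name: str) -> PolicyLossFn:
    return POLICY_LOSS_REGISTRY[name]


@register_policy_loss("toy_pg")
def toy_pg_loss(old_log_prob, log_prob, advantages, response_mask, **_) -> tuple[torch.Tensor, dict[str, Any]]:
    ratio = torch.exp(log_prob - old_log_prob)
    pg_loss = -(ratio * advantages * response_mask).sum() / response_mask.sum()
    # Any scalar diagnostics can go into the dict; the caller just merges it.
    ppo_kl = ((old_log_prob - log_prob) * response_mask).sum() / response_mask.sum()
    return pg_loss, {"actor/ppo_kl": ppo_kl.detach().item()}


# Caller side, mirroring the updated dp_actor.py / megatron_actor.py call sites:
B, T = 2, 4
pg_loss, pg_metrics = get_policy_loss_fn("toy_pg")(
    old_log_prob=torch.randn(B, T),
    log_prob=torch.randn(B, T),
    advantages=torch.randn(B, T),
    response_mask=torch.ones(B, T),
)
micro_batch_metrics = {"actor/pg_loss": pg_loss.detach().item()}
micro_batch_metrics.update(pg_metrics)
print(micro_batch_metrics)
```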

File tree

6 files changed: +56, -71 lines changed

    recipe/flowrl/flowrl_actor.py
    tests/trainer/ppo/test_rollout_corr_integration.py
    verl/trainer/ppo/core_algos.py
    verl/workers/actor/dp_actor.py
    verl/workers/actor/megatron_actor.py
    verl/workers/roles/utils/losses.py

recipe/flowrl/flowrl_actor.py

Lines changed: 0 additions & 10 deletions

@@ -356,16 +356,6 @@ def update_policy(self, data: DataProto):
                 # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
                 # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                 # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
-                # policy_loss_fn = get_policy_loss_fn(loss_mode)
-                # pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
-                #     old_log_prob=old_log_prob,
-                #     log_prob=log_prob,
-                #     advantages=advantages,
-                #     response_mask=response_mask,
-                #     loss_agg_mode=loss_agg_mode,
-                #     config=self.config,
-                #     rollout_log_probs=rollout_log_probs,
-                # )
                 # Compute FlowRL trajectory balance loss
                 policy_loss, flowrl_metrics = self.compute_flowrl_objective(
                     log_prob=log_prob,

tests/trainer/ppo/test_rollout_corr_integration.py

Lines changed: 3 additions & 3 deletions

@@ -77,7 +77,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
         rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]

         # Policy loss function receives pre-computed IS weights
-        pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
+        pg_loss, _ = compute_policy_loss_vanilla(
             old_log_prob=sample_data["old_log_prob"],
             log_prob=sample_data["log_prob"],
             advantages=sample_data["advantages"],
@@ -234,7 +234,7 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is):

         # In metrics-only mode, we compute loss WITHOUT applying weights
         # (simulating rollout_is=False)
-        pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
+        pg_loss_no_weights, _ = compute_policy_loss_vanilla(
             old_log_prob=sample_data["old_log_prob"],
             log_prob=sample_data["log_prob"],
             advantages=sample_data["advantages"],
@@ -246,7 +246,7 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is):

         # Compare to loss WITH weights (rollout_is=True)
         rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
-        pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
+        pg_loss_with_weights, _ = compute_policy_loss_vanilla(
             old_log_prob=sample_data["old_log_prob"],
             log_prob=sample_data["log_prob"],
             advantages=sample_data["advantages"],

verl/trainer/ppo/core_algos.py

Lines changed: 43 additions & 23 deletions

@@ -44,7 +44,7 @@
         Optional[DictConfig | AlgoConfig],  # config
         torch.Tensor | None,  # rollout_log_probs
     ],
-    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+    tuple[torch.Tensor, dict[str, Any]],
 ]

 POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {}
@@ -893,7 +893,7 @@ def compute_policy_loss_vanilla(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for PPO.

@@ -968,7 +968,12 @@ def compute_policy_loss_vanilla(

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("gspo")
@@ -980,7 +985,7 @@ def compute_policy_loss_gspo(
     loss_agg_mode: str = "seq-mean-token-mean",
     config: Optional[DictConfig | ActorConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for GSPO.

@@ -1037,8 +1042,12 @@ def compute_policy_loss_gspo(
     pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device)

     ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
-
-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("gpg")
@@ -1050,7 +1059,7 @@ def compute_policy_loss_gpg(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """Adapted from
     https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
     Args:
@@ -1071,7 +1080,7 @@ def compute_policy_loss_gpg(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-    return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
+    return pg_loss, {}


 @register_policy_loss("clip_cov")
@@ -1083,7 +1092,7 @@ def compute_policy_loss_clip_cov(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.

@@ -1170,8 +1179,11 @@ def compute_policy_loss_clip_cov(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-
-    return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("kl_cov")
@@ -1183,7 +1195,7 @@ def compute_policy_loss_kl_cov(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for Clip-Cov.

@@ -1246,8 +1258,10 @@ def compute_policy_loss_kl_cov(
         pg_losses = pg_losses * rollout_is_weights

     pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
-
-    return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
+    pg_metrics = {
+        "actor/ppo_kl": ppo_kl_abs.detach().item(),
+    }
+    return pg_loss, pg_metrics


 @register_policy_loss("geo_mean")
@@ -1259,7 +1273,7 @@ def compute_policy_loss_geo_mean(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """
     Compute the clipped policy objective and related metrics for GMPO.

@@ -1328,8 +1342,12 @@ def compute_policy_loss_geo_mean(
     clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp)
     pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask)
     pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask)
-
-    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
+    pg_metrics = {
+        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
+        "actor/ppo_kl": ppo_kl.detach().item(),
+        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
+    }
+    return pg_loss, pg_metrics


 def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"):
@@ -1672,12 +1690,14 @@ def compute_policy_loss_with_rollout_correction(
     negative_approx_kl = log_prob - rollout_log_prob
     kl_divergence = verl_F.masked_mean(-negative_approx_kl, effective_mask)

-    # No clipping in pure rollout correction mode
-    clip_fraction = torch.tensor(0.0)
+    pg_metrics = rollout_metrics
+    pg_metrics.update(
+        {
+            "actor/ppo_kl": kl_divergence.detach().item(),
+        }
+    )

-    # Return tuple matching compute_policy_loss signature: (loss, clip_fraction, kl, clip_fraction_lower)
-    # Note: Algorithm metrics (rollout_metrics) should be handled separately by caller
-    return pg_loss, clip_fraction, kl_divergence, clip_fraction
+    return pg_loss, pg_metrics


 @register_policy_loss("rollout_correction")
@@ -1689,7 +1709,7 @@ def compute_policy_loss_rollout_correction_wrapper(
     loss_agg_mode: str = "token-mean",
     config: Optional[DictConfig | AlgoConfig] = None,
     rollout_is_weights: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, dict[str, Any]]:
     """Wrapper for compute_policy_loss_with_rollout_correction to match PolicyLossFn interface.

     This function is used when algorithm.rollout_correction.use_pure_rollout_correction=True.
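
The practical payoff of the new `tuple[torch.Tensor, dict[str, Any]]` return type is that a custom loss can report whatever diagnostics it likes without touching the actor code. The sketch below illustrates this under stated assumptions: it reuses `register_policy_loss` and `agg_loss` from `verl.trainer.ppo.core_algos` as shown in the hunks above, but the `reinforce_with_stats` name, the REINFORCE-style objective, and the extra `actor/adv_mean` metric are invented for illustration and untested against any particular verl version.

```python
from typing import Any

import torch

from verl.trainer.ppo.core_algos import agg_loss, register_policy_loss


@register_policy_loss("reinforce_with_stats")  # hypothetical loss_mode name
def compute_policy_loss_reinforce_with_stats(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    loss_agg_mode: str = "token-mean",
    config=None,
    rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
    """Plain REINFORCE-style objective that also reports custom metrics."""
    pg_losses = -advantages * log_prob
    if rollout_is_weights is not None:
        pg_losses = pg_losses * rollout_is_weights
    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)

    mask_sum = response_mask.sum().clamp(min=1.0)
    pg_metrics = {
        # Anything scalar can go here; dp_actor/megatron_actor just merge the dict.
        "actor/ppo_kl": ((old_log_prob - log_prob) * response_mask).sum().div(mask_sum).detach().item(),
        "actor/adv_mean": (advantages * response_mask).sum().div(mask_sum).detach().item(),
    }
    return pg_loss, pg_metrics
```

Judging from the `ppo_loss` hunk in `verl/workers/roles/utils/losses.py` below, such a function would then be selected through the actor's `policy_loss.loss_mode` setting; the exact config path is an assumption here, not something this diff spells out.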

verl/workers/actor/dp_actor.py

Lines changed: 4 additions & 10 deletions

@@ -452,8 +452,8 @@ def update_policy(self, data: DataProto):
                 # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
                 policy_loss_fn = get_policy_loss_fn(loss_mode)

-                # Compute policy loss (all functions return 4 values)
-                pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                # Compute policy loss (any function is expected to return 2 values)
+                pg_loss, pg_metrics = policy_loss_fn(
                     old_log_prob=old_log_prob,
                     log_prob=log_prob,
                     advantages=advantages,
@@ -462,6 +462,7 @@ def update_policy(self, data: DataProto):
                     config=self.config,
                     rollout_is_weights=rollout_is_weights,
                 )
+                micro_batch_metrics.update(pg_metrics)

                 if entropy_coeff != 0:
                     entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
@@ -490,14 +491,7 @@ def update_policy(self, data: DataProto):
                 loss = policy_loss * loss_scale_factor
                 loss.backward()

-                micro_batch_metrics.update(
-                    {
-                        "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor,
-                        "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                        "actor/ppo_kl": ppo_kl.detach().item(),
-                        "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-                    }
-                )
+                micro_batch_metrics["actor/pg_loss"] = pg_loss.detach().item() * loss_scale_factor
                 append_to_dict(metrics, micro_batch_metrics)

             grad_norm = self._optimizer_step()

verl/workers/actor/megatron_actor.py

Lines changed: 3 additions & 15 deletions

@@ -451,12 +451,7 @@ def loss_func(output, data, meta_info):
             # Extract pre-computed rollout correction weights if present
             # Weights are computed centrally in trainer and added when algorithm.rollout_is=True
             rollout_is_weights = data.get("rollout_is_weights", None)
-
-            # NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
-            # are computed centrally in ray_trainer.py for consistency and efficiency.
-            # This ensures metrics are computed uniformly across all batches at the trainer level
-            # and avoids redundant computation across workers and micro-batches.
-            pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+            pg_loss, pg_metrics = policy_loss_fn(
                 old_log_prob=old_log_prob,
                 log_prob=log_prob,
                 advantages=advantages,
@@ -465,15 +460,8 @@ def loss_func(output, data, meta_info):
                 config=self.config,
                 rollout_is_weights=rollout_is_weights,
             )
-
-            stats.update(
-                {
-                    "actor/pg_loss": pg_loss.detach().item(),
-                    "actor/pg_clipfrac": pg_clipfrac.detach().item(),
-                    "actor/ppo_kl": ppo_kl.detach().item(),
-                    "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-                }
-            )
+            stats.update(pg_metrics)
+            stats["actor/pg_loss"] = pg_loss.detach().item()
             policy_loss = pg_loss

             if calculate_entropy:

verl/workers/roles/utils/losses.py

Lines changed: 3 additions & 10 deletions

@@ -73,23 +73,16 @@ def ppo_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
     loss_mode = config.policy_loss.get("loss_mode", "vanilla")

     policy_loss_fn = get_policy_loss_fn(loss_mode)
-    pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+    pg_loss, pg_metrics = policy_loss_fn(
         old_log_prob=old_log_prob,
         log_prob=log_prob,
         advantages=advantages,
         response_mask=response_mask,
         loss_agg_mode=loss_agg_mode,
         config=config,
     )
-
-    metrics.update(
-        {
-            "pg_loss": pg_loss.detach().item(),
-            "pg_clipfrac": pg_clipfrac.detach().item(),
-            "ppo_kl": ppo_kl.detach().item(),
-            "pg_clipfrac_lower": pg_clipfrac_lower.detach().item(),
-        }
-    )
+    metrics.update(pg_metrics)
+    metrics["actor/pg_loss"] = pg_loss.detach().item()
     policy_loss = pg_loss

     # add entropy loss

0 commit comments