
Commit 418f964

szrlee and tongyx361 authored
[BREAKING][algo] feat: Rollout Correction for General Off-Policy Problems (volcengine#3984)
## Summary

This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing:

```bibtex
@misc{liu-li-2025,
  title  = {When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch},
  url    = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year   = {2025},
  month  = sep,
}
```

**⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to the typed config.

---

## What's New

### 1. Typed Configuration with Presets

**Before (deprecated):**

```yaml
algorithm:
  rollout_is: true
  rollout_is_threshold: 2.0
  rollout_is_level: token
```

**After (Python, recommended):**

```python
from verl.trainer.config.algorithm import RolloutCorrectionConfig

# Use validated presets
config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is()
config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs()
```

**After (YAML):**

```yaml
algorithm:
  rollout_correction:
    rollout_is: token
    rollout_is_threshold: 2.0
```

**10 validated presets:**

- `token_is()` / `token_tis()` - Per-token IS
- `seq_is()` - Sequence-level IS
- `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling
- `geo_rs()` / `geo_mis()` - Geometric RS + veto
- `ppo_is_bypass()` - Bypass mode (performance)
- `pure_is()` - Pure policy gradient (no PPO clipping)
- `disabled()` - Metrics only

### 2. Mathematical Documentation

New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines)

**Theoretical foundation:**

- REINFORCE → PPO → Decoupled PPO progression
- Batch size invariance: decoupling the proximal policy from the behavior policy
- Three-policy framework: π_rollout, π_old, π_θ

**Complete formulations for:**

- Off-policy REINFORCE (`pure_is`)
- Standard PPO and bypass mode
- Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`)
- Rejection sampling (`geo_rs`)

**Diagnostic metrics:**

- KL divergence (direct and K3 estimators)
- Perplexity and perplexity ratio
- χ² divergence (token and sequence level)

**Quality:**

- Objective technical descriptions
- All formulas mathematically verified
- Cross-document consistency validated

### 3. Training Modes

| Mode | Config | Policies | Speed | Description |
|------|--------|----------|-------|-------------|
| **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance |
| **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against the rollout policy (faster) |
| **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping |

**Example:**

```python
# Bypass mode for performance
config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0)

# Pure IS for research
config = RolloutCorrectionConfig.pure_is(threshold=2.0)
```

### 4. Chi-Squared Divergence Metrics

Quantify off-policy severity:

```python
rollout_corr/chi2_token  # E[ρ²] - 1
rollout_corr/chi2_seq    # E[(∏ρ)²] - 1
```

**Interpretation:**

- χ² = 0: perfectly on-policy
- χ² < 1: low off-policiness, stable
- χ² ≥ 10: high off-policiness, correction needed

**Cleanup:**

- Removed the `mismatch_` prefix
- All metrics now live under the `rollout_corr/` namespace
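For intuition, the sketch below shows how these two diagnostics can be computed from per-token log-probabilities of the rollout policy and the trainer policy. This is a minimal stand-alone illustration, not verl's implementation; the function name, tensor shapes, and masking convention are assumptions.

```python
import torch


def chi2_diagnostics(rollout_log_probs: torch.Tensor,
                     trainer_log_probs: torch.Tensor,
                     response_mask: torch.Tensor) -> dict:
    """Illustrative chi^2 off-policy diagnostics (not verl's implementation).

    rho_t = pi_theta(a_t | s_t) / pi_rollout(a_t | s_t), and chi2 = E[rho^2] - 1.
    All inputs are (batch, seq_len); response_mask is 1.0 on response tokens.
    """
    # Per-token log importance ratio, zeroed outside the response.
    log_ratio = (trainer_log_probs - rollout_log_probs) * response_mask
    rho = torch.exp(log_ratio)

    # Token level: E[rho^2] - 1, averaged over response tokens.
    chi2_token = (rho.pow(2) * response_mask).sum() / response_mask.sum() - 1.0

    # Sequence level: rho_seq = prod_t rho_t = exp(sum_t log_ratio_t),
    # then E[rho_seq^2] - 1 over the batch.
    chi2_seq = torch.exp(log_ratio.sum(dim=-1)).pow(2).mean() - 1.0

    return {
        "rollout_corr/chi2_token": chi2_token.item(),
        "rollout_corr/chi2_seq": chi2_seq.item(),
    }
```

Note that χ² grows multiplicatively with sequence length at the sequence level, which is why `chi2_seq` typically flags off-policy drift earlier than `chi2_token`.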
### 5. Bug Fix

**Critical fix:**

- `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed
- Now raises a `ValueError` with a clear error message (a simplified sketch of this validation appears at the end of this description)

---

## Migration Guide

### Example 1: Basic Token-level IS

```python
# Old (no longer works)
config.algorithm.rollout_is = True
config.algorithm.rollout_is_threshold = 2.0
config.algorithm.rollout_is_level = "token"

# New
config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0)
```

### Example 2: Sequence IS + Rejection Sampling

```python
# Old (no longer works)
config.algorithm.rollout_is_level = "sequence"
config.algorithm.rollout_is_mode = "mask"

# New
config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs(
    is_threshold=2.0,
    rs_threshold=2.0,
)
```

### Example 3: Disable

```yaml
# Old
rollout_is: false

# New
rollout_correction: null
```

---

## References

Liu, J., Li, Y., Fu, Y., Wang, J., Liu, Q., & Shen, Y. (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)

---------

Co-authored-by: Shawn/Yuxuan Tong <[email protected]>
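As a companion to the migration guide and the bug fix in section 5, here is a rough, self-contained sketch of what a typed config with preset constructors and eager threshold validation could look like. It borrows the field names that appear in this PR (`rollout_is`, `rollout_is_threshold`, `rollout_rs`, `rollout_rs_threshold`) but is an illustration, not verl's actual `RolloutCorrectionConfig`; the defaults and the `"seq"` level name are assumptions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RolloutCorrectionConfigSketch:
    """Simplified stand-in for verl's RolloutCorrectionConfig (illustrative only)."""

    rollout_is: Optional[str] = None            # IS level, e.g. "token" or "seq"
    rollout_is_threshold: Optional[float] = None
    rollout_rs: Optional[str] = None            # rejection-sampling level, e.g. "token"
    rollout_rs_threshold: Optional[float] = None

    def __post_init__(self):
        # The bug fixed in this PR: a rejection-sampling level without a
        # threshold used to fail silently; it should fail loudly instead.
        if self.rollout_rs is not None and self.rollout_rs_threshold is None:
            raise ValueError(
                f"rollout_rs={self.rollout_rs!r} requires rollout_rs_threshold to be set"
            )

    @classmethod
    def token_is(cls, threshold: float = 2.0):
        """Preset: per-token importance sampling."""
        return cls(rollout_is="token", rollout_is_threshold=threshold)

    @classmethod
    def seq_is_rs(cls, is_threshold: float = 2.0, rs_threshold: float = 2.0):
        """Preset: sequence-level IS plus rejection sampling."""
        return cls(rollout_is="seq", rollout_is_threshold=is_threshold,
                   rollout_rs="seq", rollout_rs_threshold=rs_threshold)
```

With validation in `__post_init__`, a misconfigured rejection-sampling level fails at config-construction time rather than silently during training, which is the behavior change the bug fix describes.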
1 parent 4bf4bd3 commit 418f964

28 files changed: +3659 −1919 lines changed

docs/advance/fully_async.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -166,10 +166,10 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
 
 During the training process, we observed that metrics and response lengths may become unstable in the later
 stages of training. To mitigate this issue, we can use
-the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html)
-technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using
+the [Rollout Correction](https://verl.readthedocs.io/en/latest/advance/rollout_corr.html)
+technique for importance sampling and rejection sampling. To utilize Rollout Correction, we need to compute log_prob using
 the training engine, which requires enabling this switch.
-Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d
+Additionally, when compute_prox_log_prob and Rollout Correction are enabled under mode d
 (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`.
 
 ### Supported Modes
```
