Built site for gh-pages
301
docs/rlhf.html
@@ -801,6 +801,7 @@ gtag('config', 'G-9KYCVJBNMQ', { 'anonymize_ip': true});
|
||||
<li><a href="#reward-functions" id="toc-reward-functions" class="nav-link" data-scroll-target="#reward-functions">Reward functions</a></li>
|
||||
<li><a href="#openenv-rollout-functions" id="toc-openenv-rollout-functions" class="nav-link" data-scroll-target="#openenv-rollout-functions">OpenEnv Rollout Functions</a></li>
|
||||
<li><a href="#grpo-with-dapodr.-grpo-loss" id="toc-grpo-with-dapodr.-grpo-loss" class="nav-link" data-scroll-target="#grpo-with-dapodr.-grpo-loss">GRPO with DAPO/Dr. GRPO loss</a></li>
|
||||
<li><a href="#async-grpo" id="toc-async-grpo" class="nav-link" data-scroll-target="#async-grpo">Async GRPO</a></li>
|
||||
</ul></li>
|
||||
<li><a href="#gdpo" id="toc-gdpo" class="nav-link" data-scroll-target="#gdpo">GDPO</a>
|
||||
<ul class="collapse">
|
||||
@@ -1484,6 +1485,202 @@ Note
|
||||
<span id="cb43-4"><a href="#cb43-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">max_completion_length</span><span class="kw">:</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p>For more information, see <a href="https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types">GRPO docs</a>.</p>
|
||||
</section>
|
||||
<section id="async-grpo" class="level4">
|
||||
<h4 class="anchored" data-anchor-id="async-grpo">Async GRPO</h4>
|
||||
<p>Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. Because generation no longer blocks the training loop, this can significantly reduce wall-clock time per step.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb44"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb44-1"><a href="#cb44-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb44-2"><a href="#cb44-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_data_producer</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Enable data producer protocol</span></span>
|
||||
<span id="cb44-3"><a href="#cb44-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_vllm</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb44-4"><a href="#cb44-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">async_prefetch</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Generate rollouts in background thread</span></span>
|
||||
<span id="cb44-5"><a href="#cb44-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">prefetch_depth</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span><span class="co"> # Number of rollouts to prefetch</span></span>
|
||||
<span id="cb44-6"><a href="#cb44-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_sync_interval</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span><span class="co"> # Sync weights to vLLM every N steps</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
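<p>The overlap described above can be pictured as a bounded producer/consumer queue. The following is a minimal sketch under stated assumptions, not axolotl's implementation: <code>generate_rollout</code>, <code>producer</code>, and <code>train</code> are hypothetical names, and the vLLM call and optimizer step are replaced by stand-ins.</p>

```python
import queue
import threading


def generate_rollout(step):
    # Hypothetical stand-in for a vLLM generation call.
    return {"step": step, "completions": [f"completion-{step}-{i}" for i in range(2)]}


def producer(num_steps, out_queue):
    # Runs in the background thread. put() blocks while the queue already
    # holds `prefetch_depth` rollouts, which bounds how far generation runs ahead.
    for step in range(num_steps):
        out_queue.put(generate_rollout(step))
    out_queue.put(None)  # sentinel: no more rollouts


def train(num_steps, prefetch_depth=1):
    rollouts = queue.Queue(maxsize=prefetch_depth)
    worker = threading.Thread(target=producer, args=(num_steps, rollouts), daemon=True)
    worker.start()
    trained_steps = []
    while True:
        batch = rollouts.get()  # waits only if generation has fallen behind
        if batch is None:
            break
        trained_steps.append(batch["step"])  # stand-in for the optimizer step
    worker.join()
    return trained_steps
```

<p>In this sketch the queue's <code>maxsize</code> plays the role of <code>prefetch_depth</code>: a larger value lets generation run further ahead, at the cost of training on staler rollouts.</p>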
|
||||
<div class="callout callout-style-default callout-note callout-titled">
|
||||
<div class="callout-header d-flex align-content-center">
|
||||
<div class="callout-icon-container">
|
||||
<i class="callout-icon"></i>
|
||||
</div>
|
||||
<div class="callout-title-container flex-fill">
|
||||
Note
|
||||
</div>
|
||||
</div>
|
||||
<div class="callout-body-container callout-body">
|
||||
<p>Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by <code>vllm_importance_sampling_correction: true</code> (default when async is enabled).</p>
|
||||
</div>
|
||||
</div>
|
||||
<section id="vllm-lora-sync" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="vllm-lora-sync">vLLM LoRA Sync</h5>
|
||||
<p>By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb45"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a><span class="fu">adapter</span><span class="kw">:</span><span class="at"> lora</span></span>
|
||||
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_r</span><span class="kw">:</span><span class="at"> </span><span class="dv">32</span></span>
|
||||
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_alpha</span><span class="kw">:</span><span class="at"> </span><span class="dv">64</span></span>
|
||||
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_target_linear</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb45-5"><a href="#cb45-5" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb45-6"><a href="#cb45-6" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb45-7"><a href="#cb45-7" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_lora_sync</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Enable native LoRA sync</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p>When <code>vllm_lora_sync: true</code> is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb46"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb46-1"><a href="#cb46-1" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>0 <span class="ex">axolotl</span> vllm-serve config.yaml</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p>Then start training on a separate GPU:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb47"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb47-1"><a href="#cb47-1" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>1 <span class="ex">axolotl</span> train config.yaml</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="callout callout-style-default callout-tip callout-titled">
|
||||
<div class="callout-header d-flex align-content-center">
|
||||
<div class="callout-icon-container">
|
||||
<i class="callout-icon"></i>
|
||||
</div>
|
||||
<div class="callout-title-container flex-fill">
|
||||
Tip
|
||||
</div>
|
||||
</div>
|
||||
<div class="callout-body-container callout-body">
|
||||
<p>LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.</p>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section id="streaming-partial-batch" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="streaming-partial-batch">Streaming Partial Batch</h5>
|
||||
<p>Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb48"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb48-1"><a href="#cb48-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb48-2"><a href="#cb48-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">streaming_partial_batch</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
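<p>One way to picture streaming scoring is a generator that yields rewards one prompt group at a time. This is an illustrative sketch only: <code>stream_group_rewards</code> and <code>length_reward</code> are hypothetical names, and the real trainer's data structures will differ.</p>

```python
def length_reward(prompt, completion):
    # Hypothetical reward function used only for this example.
    return float(len(completion))


def stream_group_rewards(groups, reward_fn):
    # Score lazily: only one group's rewards are computed and held at a time,
    # so the caller can inspect or skip a group before the next one is scored.
    for prompt, completions in groups:
        yield prompt, [reward_fn(prompt, c) for c in completions]
```

<p>A caller can iterate with <code>for prompt, rewards in stream_group_rewards(groups, length_reward)</code> and decide per group whether it carries any learning signal.</p>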
|
||||
</section>
|
||||
<section id="importance-sampling-correction" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="importance-sampling-correction">Importance Sampling Correction</h5>
|
||||
<p>When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb49"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb49-1"><a href="#cb49-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb49-2"><a href="#cb49-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_importance_sampling_correction</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Enable IS correction</span></span>
|
||||
<span id="cb49-3"><a href="#cb49-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">importance_sampling_level</span><span class="kw">:</span><span class="at"> token</span><span class="co"> # 'token' or 'sequence'</span></span>
|
||||
<span id="cb49-4"><a href="#cb49-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">off_policy_mask_threshold</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span><span class="co"> # Mask sequences with IS ratio below this</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<ul>
|
||||
<li><code>importance_sampling_level: token</code> applies per-token IS ratios (recommended with Liger kernel)</li>
|
||||
<li><code>importance_sampling_level: sequence</code> applies per-sequence IS ratios</li>
|
||||
<li><code>off_policy_mask_threshold</code> masks out sequences whose IS ratio falls below the threshold, i.e. sequences that are too far off-policy</li>
|
||||
</ul>
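<p>The difference between the two levels can be sketched with plain log-probabilities. This is a simplified illustration, not TRL's code: the function names are hypothetical, the ratio is taken as current policy over sampling policy, and the sequence-level ratio is assumed to be the length-normalized (geometric mean) form.</p>

```python
import math


def token_is_ratios(new_logps, old_logps):
    # One ratio per token: pi_current(token) / pi_sampler(token).
    return [math.exp(n - o) for n, o in zip(new_logps, old_logps)]


def sequence_is_ratio(new_logps, old_logps):
    # Assumption: a single ratio per sequence, computed as the geometric
    # mean of the token ratios (exp of the mean log-ratio).
    diffs = [n - o for n, o in zip(new_logps, old_logps)]
    return math.exp(sum(diffs) / len(diffs))


def keep_sequence(new_logps, old_logps, threshold=0.5):
    # Mirrors the role of off_policy_mask_threshold: drop sequences
    # whose ratio falls below the threshold.
    return sequence_is_ratio(new_logps, old_logps) >= threshold
```

<p>When the sampling and training weights agree, every ratio is 1 and nothing is masked. The further the weights drift apart, the more the ratios move away from 1.</p>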
|
||||
</section>
|
||||
<section id="replay-buffer" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="replay-buffer">Replay Buffer</h5>
|
||||
<p>The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb50"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb50-1"><a href="#cb50-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb50-2"><a href="#cb50-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">replay_buffer_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">100</span><span class="co"> # Max cached groups (0 = disabled)</span></span>
|
||||
<span id="cb50-3"><a href="#cb50-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">replay_recompute_logps</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Recompute log-probs for replayed data (recommended)</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="callout callout-style-default callout-note callout-titled">
|
||||
<div class="callout-header d-flex align-content-center">
|
||||
<div class="callout-icon-container">
|
||||
<i class="callout-icon"></i>
|
||||
</div>
|
||||
<div class="callout-title-container flex-fill">
|
||||
Note
|
||||
</div>
|
||||
</div>
|
||||
<div class="callout-body-container callout-body">
|
||||
<p>When <code>replay_recompute_logps: true</code> (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.</p>
|
||||
</div>
|
||||
</div>
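<p>The caching and replacement logic can be sketched as follows. This is a minimal sketch under assumptions: <code>ReplayBuffer</code>, <code>has_signal</code>, and <code>fill_batch</code> are hypothetical names, cached groups are consumed oldest-first, and log-probability recomputation is omitted.</p>

```python
from collections import deque
from statistics import pvariance


def has_signal(group):
    # A group carries learning signal when its rewards are not all equal.
    return pvariance(group["rewards"]) > 0


class ReplayBuffer:
    def __init__(self, max_size):
        # deque(maxlen=0) silently drops every append, so size 0 disables caching.
        self.groups = deque(maxlen=max_size)

    def add(self, group):
        if has_signal(group):
            self.groups.append(group)


def fill_batch(batch, buffer):
    # Replace zero-signal groups with cached ones first, then cache the
    # fresh groups that had signal for use in later batches.
    filled = []
    for group in batch:
        if not has_signal(group) and buffer.groups:
            filled.append(buffer.groups.popleft())
        else:
            filled.append(group)
    for group in batch:
        buffer.add(group)
    return filled
```

<p>Caching only after replacement keeps a group from being replayed into the same batch it came from.</p>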
|
||||
</section>
|
||||
<section id="deferred-re-rolling" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="deferred-re-rolling">Deferred Re-rolling</h5>
|
||||
<p>Failed prompts (those where every generation receives zero reward) are buffered and re-injected into later batches, when the model may be better equipped to solve them.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb51"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb51-1"><a href="#cb51-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb51-2"><a href="#cb51-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reroll_start_fraction</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span><span class="co"> # Start re-rolling after 50% of training</span></span>
|
||||
<span id="cb51-3"><a href="#cb51-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reroll_max_groups</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span><span class="co"> # Max groups to replace per batch</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
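<p>A rough sketch of how the two settings could interact is shown below. The class and method names are hypothetical, and the choice to replace prompts at the end of the batch is an assumption made for illustration.</p>

```python
class RerollBuffer:
    def __init__(self, start_fraction=0.5, max_groups=1):
        self.start_fraction = start_fraction
        self.max_groups = max_groups
        self.failed = []

    def record_failures(self, prompts_with_rewards):
        # Buffer prompts whose generations all received zero reward.
        for prompt, rewards in prompts_with_rewards:
            if all(r == 0 for r in rewards):
                self.failed.append(prompt)

    def inject(self, batch_prompts, step, max_steps):
        out = list(batch_prompts)
        # Before the start fraction of training, leave batches untouched.
        if step < self.start_fraction * max_steps:
            return out
        n = min(self.max_groups, len(self.failed), len(out))
        for i in range(n):
            # Assumption: overwrite prompts from the end of the batch.
            out[-(i + 1)] = self.failed.pop(0)
        return out
```

<p>With <code>start_fraction=0.5</code> and <code>max_groups=1</code>, batches are unchanged during the first half of training, and afterwards at most one prompt per batch is swapped for a buffered failure.</p>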
|
||||
</section>
|
||||
<section id="zero-advantage-batch-skipping" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="zero-advantage-batch-skipping">Zero-Advantage Batch Skipping</h5>
|
||||
<p>When every advantage in a micro-batch is zero (no learning signal), the forward/backward pass is skipped entirely. Skipping is enabled by default, and a skipped batch is logged as <code>skipped_zero_adv_batches=1</code>.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb52"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb52-1"><a href="#cb52-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb52-2"><a href="#cb52-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">skip_zero_advantage_batches</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # default</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
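<p>The skip condition can be illustrated with mean-centered group advantages. This is a simplified sketch: the function names are hypothetical, each micro-batch is assumed to hold one group's rewards, and the standard-deviation scaling commonly used in GRPO is left out.</p>

```python
def group_advantages(rewards):
    # Mean-centered advantages: identical rewards give all zeros.
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


def train_micro_batches(micro_batches):
    stats = {"trained": 0, "skipped_zero_adv_batches": 0}
    for rewards in micro_batches:
        advantages = group_advantages(rewards)
        if all(a == 0 for a in advantages):
            # No gradient would result, so skip the forward/backward pass.
            stats["skipped_zero_adv_batches"] += 1
            continue
        stats["trained"] += 1  # stand-in for the forward/backward pass
    return stats
```

<p>A micro-batch in which every completion earned the same reward contributes nothing to the gradient, which is why skipping it does not change the update.</p>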
|
||||
</section>
|
||||
<section id="parallel-reward-workers" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="parallel-reward-workers">Parallel Reward Workers</h5>
|
||||
<p>Reward functions that use <code>signal.alarm()</code> (e.g., <code>math_verify</code>) must run in the main thread, so they cannot be parallelized with worker threads. Parallel reward workers run them in subprocesses instead: each subprocess has its own main thread, which satisfies the restriction while allowing concurrent reward computation.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb53"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb53-1"><a href="#cb53-1" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb53-2"><a href="#cb53-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_num_workers</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span><span class="co"> # Number of subprocess workers (1 = no parallelism)</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
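<p>The main-thread restriction comes from Python's <code>signal</code> module and is easy to reproduce on a Unix system. The sketch below uses hypothetical helper names and only demonstrates the restriction. It does not show axolotl's worker implementation.</p>

```python
import signal
import threading


def install_alarm_handler():
    # Registering a SIGALRM handler is what timeout-based reward
    # functions need, and Python only permits it in the main thread.
    signal.signal(signal.SIGALRM, lambda signum, frame: None)


def try_install_in_worker_thread():
    errors = []

    def target():
        try:
            install_alarm_handler()
        except ValueError as exc:
            errors.append(str(exc))

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return errors
```

<p>Calling <code>try_install_in_worker_thread()</code> returns one error message, because the handler registration raises <code>ValueError</code> outside the main thread. A separate process does not hit this, since its reward code runs in that process's own main thread.</p>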
|
||||
</section>
|
||||
<section id="full-async-grpo-example" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="full-async-grpo-example">Full Async GRPO Example</h5>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb54"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb54-1"><a href="#cb54-1" aria-hidden="true" tabindex="-1"></a><span class="fu">base_model</span><span class="kw">:</span><span class="at"> Qwen/Qwen2.5-1.5B-Instruct</span></span>
|
||||
<span id="cb54-2"><a href="#cb54-2" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb54-3"><a href="#cb54-3" aria-hidden="true" tabindex="-1"></a><span class="fu">vllm</span><span class="kw">:</span></span>
|
||||
<span id="cb54-4"><a href="#cb54-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">host</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.0.0.0</span></span>
|
||||
<span id="cb54-5"><a href="#cb54-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">port</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
|
||||
<span id="cb54-6"><a href="#cb54-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">gpu_memory_utilization</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.35</span></span>
|
||||
<span id="cb54-7"><a href="#cb54-7" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">dtype</span><span class="kw">:</span><span class="at"> auto</span></span>
|
||||
<span id="cb54-8"><a href="#cb54-8" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb54-9"><a href="#cb54-9" aria-hidden="true" tabindex="-1"></a><span class="fu">adapter</span><span class="kw">:</span><span class="at"> lora</span></span>
|
||||
<span id="cb54-10"><a href="#cb54-10" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_r</span><span class="kw">:</span><span class="at"> </span><span class="dv">32</span></span>
|
||||
<span id="cb54-11"><a href="#cb54-11" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_alpha</span><span class="kw">:</span><span class="at"> </span><span class="dv">64</span></span>
|
||||
<span id="cb54-12"><a href="#cb54-12" aria-hidden="true" tabindex="-1"></a><span class="fu">lora_target_linear</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-13"><a href="#cb54-13" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb54-14"><a href="#cb54-14" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> grpo</span></span>
|
||||
<span id="cb54-15"><a href="#cb54-15" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb54-16"><a href="#cb54-16" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_data_producer</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-17"><a href="#cb54-17" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_vllm</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-18"><a href="#cb54-18" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">async_prefetch</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-19"><a href="#cb54-19" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">prefetch_depth</span><span class="kw">:</span><span class="at"> </span><span class="dv">1</span></span>
|
||||
<span id="cb54-20"><a href="#cb54-20" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_sync_interval</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
|
||||
<span id="cb54-21"><a href="#cb54-21" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_lora_sync</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-22"><a href="#cb54-22" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">streaming_partial_batch</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-23"><a href="#cb54-23" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">vllm_importance_sampling_correction</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-24"><a href="#cb54-24" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">off_policy_mask_threshold</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span></span>
|
||||
<span id="cb54-25"><a href="#cb54-25" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">importance_sampling_level</span><span class="kw">:</span><span class="at"> token</span></span>
|
||||
<span id="cb54-26"><a href="#cb54-26" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">num_generations</span><span class="kw">:</span><span class="at"> </span><span class="dv">8</span></span>
|
||||
<span id="cb54-27"><a href="#cb54-27" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">max_completion_length</span><span class="kw">:</span><span class="at"> </span><span class="dv">512</span></span>
|
||||
<span id="cb54-28"><a href="#cb54-28" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_funcs</span><span class="kw">:</span></span>
|
||||
<span id="cb54-29"><a href="#cb54-29" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> rewards.accuracy_reward</span></span>
|
||||
<span id="cb54-30"><a href="#cb54-30" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reroll_start_fraction</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span></span>
|
||||
<span id="cb54-31"><a href="#cb54-31" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">replay_buffer_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">100</span></span>
|
||||
<span id="cb54-32"><a href="#cb54-32" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_num_workers</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span></span>
|
||||
<span id="cb54-33"><a href="#cb54-33" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">skip_zero_advantage_batches</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-34"><a href="#cb54-34" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb54-35"><a href="#cb54-35" aria-hidden="true" tabindex="-1"></a><span class="fu">datasets</span><span class="kw">:</span></span>
|
||||
<span id="cb54-36"><a href="#cb54-36" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> </span><span class="fu">path</span><span class="kw">:</span><span class="at"> AI-MO/NuminaMath-TIR</span></span>
|
||||
<span id="cb54-37"><a href="#cb54-37" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">type</span><span class="kw">:</span><span class="at"> rewards.prompt_transform</span></span>
|
||||
<span id="cb54-38"><a href="#cb54-38" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">split</span><span class="kw">:</span><span class="at"> train</span></span>
|
||||
<span id="cb54-39"><a href="#cb54-39" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb54-40"><a href="#cb54-40" aria-hidden="true" tabindex="-1"></a><span class="fu">gradient_accumulation_steps</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span></span>
|
||||
<span id="cb54-41"><a href="#cb54-41" aria-hidden="true" tabindex="-1"></a><span class="fu">micro_batch_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
|
||||
<span id="cb54-42"><a href="#cb54-42" aria-hidden="true" tabindex="-1"></a><span class="fu">max_steps</span><span class="kw">:</span><span class="at"> </span><span class="dv">500</span></span>
|
||||
<span id="cb54-43"><a href="#cb54-43" aria-hidden="true" tabindex="-1"></a><span class="fu">learning_rate</span><span class="kw">:</span><span class="at"> </span><span class="fl">1e-5</span></span>
|
||||
<span id="cb54-44"><a href="#cb54-44" aria-hidden="true" tabindex="-1"></a><span class="fu">bf16</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb54-45"><a href="#cb54-45" aria-hidden="true" tabindex="-1"></a><span class="fu">gradient_checkpointing</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb55"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb55-1"><a href="#cb55-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Terminal 1: Start vLLM on GPU 0</span></span>
|
||||
<span id="cb55-2"><a href="#cb55-2" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>0 <span class="ex">axolotl</span> vllm-serve config.yaml</span>
|
||||
<span id="cb55-3"><a href="#cb55-3" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb55-4"><a href="#cb55-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Terminal 2: Train on GPU 1</span></span>
|
||||
<span id="cb55-5"><a href="#cb55-5" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>1 <span class="ex">axolotl</span> train config.yaml</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
</section>
|
||||
<section id="multi-gpu-async-grpo" class="level5">
|
||||
<h5 class="anchored" data-anchor-id="multi-gpu-async-grpo">Multi-GPU Async GRPO</h5>
|
||||
<p>Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.</p>
|
||||
<p><strong>FSDP:</strong></p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb56"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb56-1"><a href="#cb56-1" aria-hidden="true" tabindex="-1"></a><span class="fu">fsdp</span><span class="kw">:</span></span>
|
||||
<span id="cb56-2"><a href="#cb56-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> full_shard</span></span>
|
||||
<span id="cb56-3"><a href="#cb56-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> auto_wrap</span></span>
|
||||
<span id="cb56-4"><a href="#cb56-4" aria-hidden="true" tabindex="-1"></a><span class="fu">fsdp_config</span><span class="kw">:</span></span>
|
||||
<span id="cb56-5"><a href="#cb56-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">fsdp_transformer_layer_cls_to_wrap</span><span class="kw">:</span><span class="at"> Qwen2DecoderLayer</span></span>
|
||||
<span id="cb56-6"><a href="#cb56-6" aria-hidden="true" tabindex="-1"></a><span class="fu">gradient_checkpointing_kwargs</span><span class="kw">:</span></span>
|
||||
<span id="cb56-7"><a href="#cb56-7" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_reentrant</span><span class="kw">:</span><span class="at"> </span><span class="ch">false</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p><strong>DeepSpeed ZeRO-3:</strong></p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb57"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb57-1"><a href="#cb57-1" aria-hidden="true" tabindex="-1"></a><span class="fu">deepspeed</span><span class="kw">:</span><span class="at"> deepspeed_configs/zero3_bf16.json</span></span>
|
||||
<span id="cb57-2"><a href="#cb57-2" aria-hidden="true" tabindex="-1"></a><span class="fu">gradient_checkpointing_kwargs</span><span class="kw">:</span></span>
|
||||
<span id="cb57-3"><a href="#cb57-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_reentrant</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span><span class="co"> # Required for ZeRO-3</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb58"><pre class="sourceCode bash code-with-copy"><code class="sourceCode bash"><span id="cb58-1"><a href="#cb58-1" aria-hidden="true" tabindex="-1"></a><span class="co"># Terminal 1: Start vLLM on GPU 0</span></span>
|
||||
<span id="cb58-2"><a href="#cb58-2" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>0 <span class="ex">axolotl</span> vllm-serve config.yaml</span>
|
||||
<span id="cb58-3"><a href="#cb58-3" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb58-4"><a href="#cb58-4" aria-hidden="true" tabindex="-1"></a><span class="co"># Terminal 2: Train on GPUs 1,2</span></span>
|
||||
<span id="cb58-5"><a href="#cb58-5" aria-hidden="true" tabindex="-1"></a><span class="va">CUDA_VISIBLE_DEVICES</span><span class="op">=</span>1,2 <span class="ex">accelerate</span> launch <span class="at">--num_processes</span> 2 <span class="at">-m</span> axolotl.cli.train config.yaml</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="callout callout-style-default callout-important callout-titled">
|
||||
<div class="callout-header d-flex align-content-center">
|
||||
<div class="callout-icon-container">
|
||||
<i class="callout-icon"></i>
|
||||
</div>
|
||||
<div class="callout-title-container flex-fill">
|
||||
Important
|
||||
</div>
|
||||
</div>
|
||||
<div class="callout-body-container callout-body">
|
||||
<p>With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.</p>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</section>
|
||||
</section>
|
||||
<section id="gdpo" class="level3">
|
||||
<h3 class="anchored" data-anchor-id="gdpo">GDPO</h3>
|
||||
@@ -1503,35 +1700,35 @@ Tip
|
||||
</div>
|
||||
<p>Paper: <a href="https://arxiv.org/pdf/2501.05242">https://arxiv.org/pdf/2501.05242</a></p>
|
||||
<p>GDPO uses TRL’s native <code>multi_objective_aggregation</code> parameter under the hood. When you set <code>rl: gdpo</code>, axolotl automatically configures TRL to use <code>normalize_then_sum</code> aggregation.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb44"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb44-1"><a href="#cb44-1" aria-hidden="true" tabindex="-1"></a><span class="fu">base_model</span><span class="kw">:</span><span class="at"> Qwen/Qwen2.5-1.5B-Instruct</span></span>
|
||||
<span id="cb44-2"><a href="#cb44-2" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb44-3"><a href="#cb44-3" aria-hidden="true" tabindex="-1"></a><span class="fu">vllm</span><span class="kw">:</span></span>
|
||||
<span id="cb44-4"><a href="#cb44-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">host</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.0.0.0</span></span>
|
||||
<span id="cb44-5"><a href="#cb44-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">port</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
|
||||
<span id="cb44-6"><a href="#cb44-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">tensor_parallel_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
|
||||
<span id="cb44-7"><a href="#cb44-7" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">gpu_memory_utilization</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.85</span></span>
|
||||
<span id="cb44-8"><a href="#cb44-8" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb44-9"><a href="#cb44-9" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> gdpo</span></span>
|
||||
<span id="cb44-10"><a href="#cb44-10" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb44-11"><a href="#cb44-11" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb44-12"><a href="#cb44-12" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">beta</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.001</span></span>
|
||||
<span id="cb44-13"><a href="#cb44-13" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">max_completion_length</span><span class="kw">:</span><span class="at"> </span><span class="dv">256</span></span>
|
||||
<span id="cb44-14"><a href="#cb44-14" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_vllm</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb44-15"><a href="#cb44-15" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">num_generations</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span></span>
|
||||
<span id="cb44-16"><a href="#cb44-16" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_funcs</span><span class="kw">:</span></span>
|
||||
<span id="cb44-17"><a href="#cb44-17" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> rewards.format_reward</span></span>
|
||||
<span id="cb44-18"><a href="#cb44-18" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> rewards.correctness_reward</span></span>
|
||||
<span id="cb44-19"><a href="#cb44-19" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_weights</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="fl">1.0</span><span class="kw">,</span><span class="at"> </span><span class="fl">2.0</span><span class="kw">]</span></span>
|
||||
<span id="cb44-20"><a href="#cb44-20" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb44-21"><a href="#cb44-21" aria-hidden="true" tabindex="-1"></a><span class="fu">datasets</span><span class="kw">:</span></span>
|
||||
<span id="cb44-22"><a href="#cb44-22" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> </span><span class="fu">path</span><span class="kw">:</span><span class="at"> openai/gsm8k</span></span>
|
||||
<span id="cb44-23"><a href="#cb44-23" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> main</span></span>
|
||||
<span id="cb44-24"><a href="#cb44-24" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">type</span><span class="kw">:</span><span class="at"> rewards.oai_gsm8k_transform</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb59"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb59-1"><a href="#cb59-1" aria-hidden="true" tabindex="-1"></a><span class="fu">base_model</span><span class="kw">:</span><span class="at"> Qwen/Qwen2.5-1.5B-Instruct</span></span>
|
||||
<span id="cb59-2"><a href="#cb59-2" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb59-3"><a href="#cb59-3" aria-hidden="true" tabindex="-1"></a><span class="fu">vllm</span><span class="kw">:</span></span>
|
||||
<span id="cb59-4"><a href="#cb59-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">host</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.0.0.0</span></span>
|
||||
<span id="cb59-5"><a href="#cb59-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">port</span><span class="kw">:</span><span class="at"> </span><span class="dv">8000</span></span>
|
||||
<span id="cb59-6"><a href="#cb59-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">tensor_parallel_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span>
|
||||
<span id="cb59-7"><a href="#cb59-7" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">gpu_memory_utilization</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.85</span></span>
|
||||
<span id="cb59-8"><a href="#cb59-8" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb59-9"><a href="#cb59-9" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> gdpo</span></span>
|
||||
<span id="cb59-10"><a href="#cb59-10" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb59-11"><a href="#cb59-11" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb59-12"><a href="#cb59-12" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">beta</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.001</span></span>
|
||||
<span id="cb59-13"><a href="#cb59-13" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">max_completion_length</span><span class="kw">:</span><span class="at"> </span><span class="dv">256</span></span>
|
||||
<span id="cb59-14"><a href="#cb59-14" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">use_vllm</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span>
|
||||
<span id="cb59-15"><a href="#cb59-15" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">num_generations</span><span class="kw">:</span><span class="at"> </span><span class="dv">4</span></span>
|
||||
<span id="cb59-16"><a href="#cb59-16" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_funcs</span><span class="kw">:</span></span>
|
||||
<span id="cb59-17"><a href="#cb59-17" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> rewards.format_reward</span></span>
|
||||
<span id="cb59-18"><a href="#cb59-18" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> rewards.correctness_reward</span></span>
|
||||
<span id="cb59-19"><a href="#cb59-19" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">reward_weights</span><span class="kw">:</span><span class="at"> </span><span class="kw">[</span><span class="fl">1.0</span><span class="kw">,</span><span class="at"> </span><span class="fl">2.0</span><span class="kw">]</span></span>
|
||||
<span id="cb59-20"><a href="#cb59-20" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb59-21"><a href="#cb59-21" aria-hidden="true" tabindex="-1"></a><span class="fu">datasets</span><span class="kw">:</span></span>
|
||||
<span id="cb59-22"><a href="#cb59-22" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> </span><span class="fu">path</span><span class="kw">:</span><span class="at"> openai/gsm8k</span></span>
|
||||
<span id="cb59-23"><a href="#cb59-23" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">name</span><span class="kw">:</span><span class="at"> main</span></span>
|
||||
<span id="cb59-24"><a href="#cb59-24" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">type</span><span class="kw">:</span><span class="at"> rewards.oai_gsm8k_transform</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p>You can also use GRPO with explicit aggregation control:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb45"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb45-1"><a href="#cb45-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> grpo</span></span>
|
||||
<span id="cb45-2"><a href="#cb45-2" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb45-3"><a href="#cb45-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">multi_objective_aggregation</span><span class="kw">:</span><span class="at"> normalize_then_sum</span><span class="co"> # GDPO behavior</span></span>
|
||||
<span id="cb45-4"><a href="#cb45-4" aria-hidden="true" tabindex="-1"></a><span class="co"> # or: sum_then_normalize # Default GRPO behavior</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb60"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb60-1"><a href="#cb60-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> grpo</span></span>
|
||||
<span id="cb60-2"><a href="#cb60-2" aria-hidden="true" tabindex="-1"></a><span class="fu">trl</span><span class="kw">:</span></span>
|
||||
<span id="cb60-3"><a href="#cb60-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">multi_objective_aggregation</span><span class="kw">:</span><span class="at"> normalize_then_sum</span><span class="co"> # GDPO behavior</span></span>
|
||||
<span id="cb60-4"><a href="#cb60-4" aria-hidden="true" tabindex="-1"></a><span class="co"> # or: sum_then_normalize # Default GRPO behavior</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
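<p>To see why the order of aggregation matters, the following sketch computes group-relative advantages both ways in plain Python. It assumes one group of four completions and two reward functions with very different scales; the helper names and the epsilon value are illustrative assumptions, and TRL's actual implementation may differ in details such as additional normalization steps.</p>

```python
from statistics import mean, pstdev


def normalize(values, eps=1e-4):
    """Standardize one group of rewards to zero mean and (roughly) unit variance."""
    mu = mean(values)
    sigma = pstdev(values)
    return [(v - mu) / (sigma + eps) for v in values]


def sum_then_normalize(reward_columns, weights):
    """GRPO-style: combine the weighted rewards first, then normalize the total."""
    n = len(reward_columns[0])
    totals = [
        sum(w * col[i] for w, col in zip(weights, reward_columns)) for i in range(n)
    ]
    return normalize(totals)


def normalize_then_sum(reward_columns, weights):
    """GDPO-style: normalize each reward on its own, then take the weighted sum."""
    n = len(reward_columns[0])
    normalized = [normalize(col) for col in reward_columns]
    return [sum(w * col[i] for w, col in zip(weights, normalized)) for i in range(n)]


# Illustrative rewards for one prompt with four sampled completions.
format_scores = [1.0, 1.0, 1.0, 0.0]        # small-scale reward
correctness_scores = [0.0, 10.0, 0.0, 0.0]  # large-scale reward
weights = [1.0, 2.0]

grpo_adv = sum_then_normalize([format_scores, correctness_scores], weights)
gdpo_adv = normalize_then_sum([format_scores, correctness_scores], weights)

if __name__ == "__main__":
    print("sum_then_normalize:", [round(a, 3) for a in grpo_adv])
    print("normalize_then_sum:", [round(a, 3) for a in gdpo_adv])
```

<p>With these numbers, the first and last completions differ only in the format reward. Under <code>sum_then_normalize</code> that difference is almost lost next to the large correctness reward, while <code>normalize_then_sum</code> keeps it clearly visible in the advantages.</p>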
|
||||
<section id="gdpo-vs-grpo" class="level4">
|
||||
<h4 class="anchored" data-anchor-id="gdpo-vs-grpo">GDPO vs GRPO</h4>
|
||||
<table class="caption-top table">
|
||||
@@ -1579,47 +1776,47 @@ Tip
|
||||
<section id="reward-functions-1" class="level4">
|
||||
<h4 class="anchored" data-anchor-id="reward-functions-1">Reward Functions</h4>
|
||||
<p>GDPO uses the same reward function format as GRPO:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb47"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb47-1"><a href="#cb47-1" aria-hidden="true" tabindex="-1"></a><span class="co"># rewards.py</span></span>
|
||||
<span id="cb47-2"><a href="#cb47-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> format_reward(completions, <span class="op">**</span>kwargs) <span class="op">-></span> <span class="bu">list</span>[<span class="bu">float</span>]:</span>
|
||||
<span id="cb47-3"><a href="#cb47-3" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> [<span class="fl">1.0</span> <span class="cf">if</span> <span class="bu">len</span>(c) <span class="op">></span> <span class="dv">10</span> <span class="cf">else</span> <span class="fl">0.0</span> <span class="cf">for</span> c <span class="kw">in</span> completions]</span>
|
||||
<span id="cb47-4"><a href="#cb47-4" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb47-5"><a href="#cb47-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> correctness_reward(completions, answers, <span class="op">**</span>kwargs) <span class="op">-></span> <span class="bu">list</span>[<span class="bu">float</span>]:</span>
|
||||
<span id="cb47-6"><a href="#cb47-6" aria-hidden="true" tabindex="-1"></a> rewards <span class="op">=</span> []</span>
|
||||
<span id="cb47-7"><a href="#cb47-7" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> completion, answer <span class="kw">in</span> <span class="bu">zip</span>(completions, answers):</span>
|
||||
<span id="cb47-8"><a href="#cb47-8" aria-hidden="true" tabindex="-1"></a> <span class="co"># Your scoring logic here</span></span>
|
||||
<span id="cb47-9"><a href="#cb47-9" aria-hidden="true" tabindex="-1"></a> rewards.append(score)</span>
|
||||
<span id="cb47-10"><a href="#cb47-10" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> rewards</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb62"><pre class="sourceCode python code-with-copy"><code class="sourceCode python"><span id="cb62-1"><a href="#cb62-1" aria-hidden="true" tabindex="-1"></a><span class="co"># rewards.py</span></span>
|
||||
<span id="cb62-2"><a href="#cb62-2" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> format_reward(completions, <span class="op">**</span>kwargs) <span class="op">-></span> <span class="bu">list</span>[<span class="bu">float</span>]:</span>
|
||||
<span id="cb62-3"><a href="#cb62-3" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> [<span class="fl">1.0</span> <span class="cf">if</span> <span class="bu">len</span>(c) <span class="op">></span> <span class="dv">10</span> <span class="cf">else</span> <span class="fl">0.0</span> <span class="cf">for</span> c <span class="kw">in</span> completions]</span>
|
||||
<span id="cb62-4"><a href="#cb62-4" aria-hidden="true" tabindex="-1"></a></span>
|
||||
<span id="cb62-5"><a href="#cb62-5" aria-hidden="true" tabindex="-1"></a><span class="kw">def</span> correctness_reward(completions, answers, <span class="op">**</span>kwargs) <span class="op">-></span> <span class="bu">list</span>[<span class="bu">float</span>]:</span>
|
||||
<span id="cb62-6"><a href="#cb62-6" aria-hidden="true" tabindex="-1"></a> rewards <span class="op">=</span> []</span>
|
||||
<span id="cb62-7"><a href="#cb62-7" aria-hidden="true" tabindex="-1"></a> <span class="cf">for</span> completion, answer <span class="kw">in</span> <span class="bu">zip</span>(completions, answers):</span>
|
||||
<span id="cb62-8"><a href="#cb62-8" aria-hidden="true" tabindex="-1"></a> <span class="co"># Your scoring logic here</span></span>
|
||||
<span id="cb62-9"><a href="#cb62-9" aria-hidden="true" tabindex="-1"></a> rewards.append(score)</span>
|
||||
<span id="cb62-10"><a href="#cb62-10" aria-hidden="true" tabindex="-1"></a> <span class="cf">return</span> rewards</span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
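<p>The template above leaves <code>score</code> undefined. A minimal runnable version might look like the following. It assumes completions arrive as plain strings and that the dataset transform exposes an <code>answers</code> column; the helper <code>extract_last_number</code> and the exact-match scoring rule are illustrative choices, not part of axolotl or TRL.</p>

```python
# rewards.py
import re


def format_reward(completions, **kwargs) -> list[float]:
    """Reward completions that are longer than 10 characters."""
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def extract_last_number(text):
    """Hypothetical helper: return the last number in the text as a float, or None."""
    # Commas are dropped so that "1,000" parses as 1000 (GSM8K-style answers).
    matches = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return float(matches[-1]) if matches else None


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    """Score 1.0 when the last number in the completion equals the reference answer."""
    rewards = []
    for completion, answer in zip(completions, answers):
        predicted = extract_last_number(completion)
        expected = extract_last_number(str(answer))
        score = 1.0 if predicted is not None and predicted == expected else 0.0
        rewards.append(score)
    return rewards


if __name__ == "__main__":
    completions = ["The answer is 42.", "I think it is 7", "no"]
    answers = ["42", "8", "3"]
    print(format_reward(completions))
    print(correctness_reward(completions, answers))
```

<p>Both functions return one float per completion, in the same order as the input, which is what the <code>reward_funcs</code> entries in the config above are expected to provide.</p>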
|
||||
</section>
|
||||
<section id="sequence-parallelism" class="level4">
|
||||
<h4 class="anchored" data-anchor-id="sequence-parallelism">Sequence Parallelism</h4>
|
||||
<p>GDPO supports sequence parallelism for long-context training:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb48"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb48-1"><a href="#cb48-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> gdpo</span></span>
|
||||
<span id="cb48-2"><a href="#cb48-2" aria-hidden="true" tabindex="-1"></a><span class="fu">context_parallel_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb63"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb63-1"><a href="#cb63-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> gdpo</span></span>
|
||||
<span id="cb63-2"><a href="#cb63-2" aria-hidden="true" tabindex="-1"></a><span class="fu">context_parallel_size</span><span class="kw">:</span><span class="at"> </span><span class="dv">2</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
</section>
|
||||
</section>
|
||||
<section id="simpo" class="level3">
|
||||
<h3 class="anchored" data-anchor-id="simpo">SimPO</h3>
|
||||
<p>SimPO uses <a href="https://huggingface.co/docs/trl/main/en/cpo_trainer">CPOTrainer</a> but with an alternative loss function.</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb49"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb49-1"><a href="#cb49-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> simpo</span></span>
|
||||
<span id="cb49-2"><a href="#cb49-2" aria-hidden="true" tabindex="-1"></a><span class="fu">rl_beta</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.1</span><span class="co"> # default in CPOTrainer</span></span>
|
||||
<span id="cb49-3"><a href="#cb49-3" aria-hidden="true" tabindex="-1"></a><span class="fu">cpo_alpha</span><span class="kw">:</span><span class="at"> </span><span class="fl">1.0</span><span class="co"> # default in CPOTrainer</span></span>
|
||||
<span id="cb49-4"><a href="#cb49-4" aria-hidden="true" tabindex="-1"></a><span class="fu">simpo_gamma</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span><span class="co"> # default in CPOTrainer</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb64"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb64-1"><a href="#cb64-1" aria-hidden="true" tabindex="-1"></a><span class="fu">rl</span><span class="kw">:</span><span class="at"> simpo</span></span>
|
||||
<span id="cb64-2"><a href="#cb64-2" aria-hidden="true" tabindex="-1"></a><span class="fu">rl_beta</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.1</span><span class="co"> # default in CPOTrainer</span></span>
|
||||
<span id="cb64-3"><a href="#cb64-3" aria-hidden="true" tabindex="-1"></a><span class="fu">cpo_alpha</span><span class="kw">:</span><span class="at"> </span><span class="fl">1.0</span><span class="co"> # default in CPOTrainer</span></span>
|
||||
<span id="cb64-4"><a href="#cb64-4" aria-hidden="true" tabindex="-1"></a><span class="fu">simpo_gamma</span><span class="kw">:</span><span class="at"> </span><span class="fl">0.5</span><span class="co"> # default in CPOTrainer</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<p>This method uses the same dataset format as <a href="#dpo">DPO</a>.</p>
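<p>For intuition about what <code>rl_beta</code> and <code>simpo_gamma</code> control, here is a sketch of the SimPO preference term for a single chosen/rejected pair. It assumes the inputs are length-averaged log-probabilities and it leaves out the additional term weighted by <code>cpo_alpha</code>; the function name is hypothetical and this is not the CPOTrainer code itself.</p>

```python
import math


def simpo_preference_loss(avg_logp_chosen, avg_logp_rejected, beta=0.1, gamma=0.5):
    """SimPO preference term: -log(sigmoid(beta * (chosen - rejected) - gamma)).

    avg_logp_* are assumed to be per-token average log-probabilities under the
    policy. No reference model appears in this term.
    """
    margin = beta * (avg_logp_chosen - avg_logp_rejected) - gamma
    # -log(sigmoid(x)) equals log(1 + exp(-x)).
    return math.log1p(math.exp(-margin))


if __name__ == "__main__":
    # The loss shrinks as the chosen response becomes more likely than the rejected one.
    print(simpo_preference_loss(-1.0, -1.0))
    print(simpo_preference_loss(-0.5, -2.0))
```

<p>In this form, <code>gamma</code> acts as a target margin: the pair keeps contributing noticeable loss until the scaled log-probability gap exceeds it.</p>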
|
||||
</section>
|
||||
<section id="using-local-dataset-files" class="level3">
|
||||
<h3 class="anchored" data-anchor-id="using-local-dataset-files">Using local dataset files</h3>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb50"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb50-1"><a href="#cb50-1" aria-hidden="true" tabindex="-1"></a><span class="fu">datasets</span><span class="kw">:</span></span>
|
||||
<span id="cb50-2"><a href="#cb50-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> </span><span class="fu">ds_type</span><span class="kw">:</span><span class="at"> json</span></span>
|
||||
<span id="cb50-3"><a href="#cb50-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">data_files</span><span class="kw">:</span></span>
|
||||
<span id="cb50-4"><a href="#cb50-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> orca_rlhf.jsonl</span></span>
|
||||
<span id="cb50-5"><a href="#cb50-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">split</span><span class="kw">:</span><span class="at"> train</span></span>
|
||||
<span id="cb50-6"><a href="#cb50-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">type</span><span class="kw">:</span><span class="at"> chatml.intel</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb65"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb65-1"><a href="#cb65-1" aria-hidden="true" tabindex="-1"></a><span class="fu">datasets</span><span class="kw">:</span></span>
|
||||
<span id="cb65-2"><a href="#cb65-2" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> </span><span class="fu">ds_type</span><span class="kw">:</span><span class="at"> json</span></span>
|
||||
<span id="cb65-3"><a href="#cb65-3" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">data_files</span><span class="kw">:</span></span>
|
||||
<span id="cb65-4"><a href="#cb65-4" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="kw">-</span><span class="at"> orca_rlhf.jsonl</span></span>
|
||||
<span id="cb65-5"><a href="#cb65-5" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">split</span><span class="kw">:</span><span class="at"> train</span></span>
|
||||
<span id="cb65-6"><a href="#cb65-6" aria-hidden="true" tabindex="-1"></a><span class="at"> </span><span class="fu">type</span><span class="kw">:</span><span class="at"> chatml.intel</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
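<p>A file such as <code>orca_rlhf.jsonl</code> can be produced with a few lines of Python. The sketch below assumes the <code>chatml.intel</code> type expects <code>system</code>, <code>question</code>, <code>chosen</code>, and <code>rejected</code> fields, as in the DPO dataset formats described earlier on this page; the record contents and the <code>write_jsonl</code> helper are made up for illustration.</p>

```python
import json
from pathlib import Path

# Illustrative preference records; the field names are an assumption (see above).
records = [
    {
        "system": "You are a helpful assistant.",
        "question": "What is 2 + 2?",
        "chosen": "2 + 2 = 4.",
        "rejected": "2 + 2 = 5.",
    },
]


def write_jsonl(path, rows):
    """Write one JSON object per line, which is what `ds_type: json` can read."""
    with Path(path).open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    write_jsonl("orca_rlhf.jsonl", records)
```

<p>Each line of the resulting file is a standalone JSON object, so the file can be appended to or streamed without loading it whole.</p>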
|
||||
</section>
|
||||
<section id="trl-auto-unwrapping-for-peft" class="level3">
|
||||
<h3 class="anchored" data-anchor-id="trl-auto-unwrapping-for-peft">TRL auto-unwrapping for PEFT</h3>
|
||||
<p>TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, because an additional reference model does not need to be loaded: reference model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:</p>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb51"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb51-1"><a href="#cb51-1" aria-hidden="true" tabindex="-1"></a><span class="co"># load ref model when adapter training.</span></span>
|
||||
<span id="cb51-2"><a href="#cb51-2" aria-hidden="true" tabindex="-1"></a><span class="fu">rl_adapter_ref_model</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
<div class="code-copy-outer-scaffold"><div class="sourceCode" id="cb66"><pre class="sourceCode yaml code-with-copy"><code class="sourceCode yaml"><span id="cb66-1"><a href="#cb66-1" aria-hidden="true" tabindex="-1"></a><span class="co"># load ref model when adapter training.</span></span>
|
||||
<span id="cb66-2"><a href="#cb66-2" aria-hidden="true" tabindex="-1"></a><span class="fu">rl_adapter_ref_model</span><span class="kw">:</span><span class="at"> </span><span class="ch">true</span></span></code></pre></div><button title="Copy to Clipboard" class="code-copy-button"><i class="bi"></i></button></div>
|
||||
|
||||
|
||||
</section>
|
||||
|
||||