diff --git a/docs/api/core.trainers.grpo.trainer.html b/docs/api/core.trainers.grpo.trainer.html
core.trainers.grpo.trainer.AxolotlAsyncGRPOTrainer(*args, **kwargs)
    Extend AsyncGRPOTrainer with axolotl helpers

core.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer(
    model,
    reward_funcs,
    args=None,
    train_dataset=None,
    eval_dataset=None,
    processing_class=None,
    reward_processing_classes=None,
    callbacks=None,
    optimizers=(None, None),
    peft_config=None,
    optimizer_cls_and_kwargs=None,
)
    Extend the base GRPOTrainer for sequence parallelism handling

core.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer.get_train_dataloader()
    Get dataloader for training

core.trainers.grpo.trainer.AxolotlGRPOTrainer(*args, **kwargs)
    Extend the base GRPOTrainer for axolotl helpers
diff --git a/docs/api/index.html b/docs/api/index.html

Module summary: bitsandbytes and FP8 integration.

Updated parameter and return types:

- Returns: QuantState | torch.Tensor | None. None if not available.
- W_quant (QuantState | torch.Tensor | None): Quantization state for W. Required.
- quant_state (QuantState | list | torch.Tensor | None): Quantization state containing metadata needed for dequantization. Can be either a QuantState object or legacy list format. If None, returns W unchanged. Default: None.

| Function | Description |
|---|---|
| dequantize | Fast NF4 dequantization using bitsandbytes CUDA kernels. |
| dequantize_fp8 | Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv. |
dequantize_fp8 parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | FP8 weight tensor [out_features, in_features] in float8_e4m3fn. | required |
| scale_inv | torch.Tensor | Per-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar. | required |
| dtype | torch.dtype | Output dtype (default bf16). | torch.bfloat16 |

dequantize_fp8 returns:

| Type | Description |
|---|---|
| torch.Tensor | Dequantized tensor in the specified dtype. |
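For reference, a minimal PyTorch sketch of the documented formula W_dequant = W_fp8 * scale_inv, assuming a square quantization block (the block size of 128 and the function name are illustrative; this is not the library's CUDA-backed implementation):

```python
import torch

def dequantize_fp8_sketch(W: torch.Tensor, scale_inv: torch.Tensor,
                          dtype: torch.dtype = torch.bfloat16,
                          block: int = 128) -> torch.Tensor:
    """W_dequant = W_fp8 * scale_inv, expanded over square blocks (illustrative only)."""
    W32 = W.to(torch.float32)
    if scale_inv.numel() == 1:                        # per-tensor scalar scale
        return (W32 * scale_inv.to(torch.float32)).to(dtype)
    out_f, in_f = W.shape
    # Expand the [ceil(out/block), ceil(in/block)] grid of inverse scales to full size
    scales = scale_inv.to(torch.float32)
    scales = scales.repeat_interleave(block, dim=0)[:out_f]
    scales = scales.repeat_interleave(block, dim=1)[:, :in_f]
    return (W32 * scales).to(dtype)
```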
For more information, see GRPO docs.
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
  use_data_producer: true   # Enable data producer protocol
  use_vllm: true
  async_prefetch: true      # Generate rollouts in background thread
  prefetch_depth: 1         # Number of rollouts to prefetch
  vllm_sync_interval: 2     # Sync weights to vLLM every N steps
```

Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (the default when async is enabled).
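As a rough mental model of the prefetch loop, here is a hypothetical producer/consumer sketch; the function names and queue-based handoff are illustrative, not axolotl's internals:

```python
import queue
import threading

def prefetched_rollouts(generate_fn, prompt_batches, prefetch_depth=1):
    """Producer/consumer sketch of async prefetch (illustrative, not axolotl internals)."""
    rollouts = queue.Queue(maxsize=prefetch_depth)   # bounds how far generation runs ahead

    def produce():
        for prompts in prompt_batches:
            # Background thread generates with slightly stale policy weights
            rollouts.put(generate_fn(prompts))
        rollouts.put(None)                           # end-of-data sentinel

    threading.Thread(target=produce, daemon=True).start()
    while True:
        batch = rollouts.get()                       # training thread consumes the next batch
        if batch is None:
            break
        yield batch
```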
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

trl:
  vllm_lora_sync: true  # Enable native LoRA sync
```

When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:

```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```

Then start training on a separate GPU:

```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```

LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.

```yaml
trl:
  streaming_partial_batch: true
```

When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
  vllm_importance_sampling_correction: true  # Enable IS correction
  importance_sampling_level: token           # 'token' or 'sequence'
  off_policy_mask_threshold: 0.5             # Mask sequences with IS ratio below this
```

- `importance_sampling_level: token` applies per-token IS ratios (recommended with the Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences whose IS ratio indicates they are too far off-policy
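To make the correction concrete, here is a hypothetical sketch of how per-token IS ratios and off-policy masking could weight the policy loss; the tensor names and masking details are assumptions, not the TRL implementation:

```python
import torch

def apply_is_correction(per_token_loss, logp_current, logp_rollout,
                        level="token", mask_threshold=0.5):
    """Weight a per-token policy loss by IS ratios (illustrative sketch only)."""
    log_ratio = logp_current - logp_rollout            # [batch, seq_len]
    if level == "sequence":
        # One ratio per sequence, broadcast back over its tokens
        log_ratio = log_ratio.sum(dim=-1, keepdim=True)
    ratio = log_ratio.exp().detach()

    # Sequence-level ratio decides whether a rollout is too far off-policy
    seq_ratio = (logp_current - logp_rollout).sum(dim=-1).exp()
    keep = (seq_ratio >= mask_threshold).to(per_token_loss.dtype).unsqueeze(-1)

    return (per_token_loss * ratio * keep).mean()
```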
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.

```yaml
trl:
  replay_buffer_size: 100        # Max cached groups (0 = disabled)
  replay_recompute_logps: true   # Recompute log-probs for replayed data (recommended)
```

When `replay_recompute_logps: true` (the default), old log-probabilities are recomputed with the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
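Conceptually, the buffer behaves like the following simplified, hypothetical sketch (not axolotl's actual data structures):

```python
import random

class ReplayBufferSketch:
    """Cache rollout groups that carry learning signal (simplified illustration)."""

    def __init__(self, max_size=100):
        self.max_size = max_size
        self.groups = []

    def maybe_store(self, group):
        # Keep only groups whose rewards vary, i.e. non-zero advantage within the group
        if self.max_size > 0 and len(set(group["rewards"])) > 1:
            self.groups.append(group)
            self.groups = self.groups[-self.max_size:]

    def replace_zero_signal(self, batch_groups):
        # Swap zero-variance groups for cached ones that still have signal
        return [
            random.choice(self.groups)
            if len(set(g["rewards"])) == 1 and self.groups else g
            for g in batch_groups
        ]
```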
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
  reroll_start_fraction: 0.5  # Start re-rolling after 50% of training
  reroll_max_groups: 1        # Max groups to replace per batch
```

When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
  skip_zero_advantage_batches: true  # default
```

Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
  reward_num_workers: 4  # Number of subprocess workers (1 = no parallelism)
```
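The subprocess approach can be pictured with this hypothetical sketch using ProcessPoolExecutor; the helper names and call signature are placeholders, not axolotl's API. Each worker process has its own main thread, so signal.alarm()-based timeouts work there:

```python
from concurrent.futures import ProcessPoolExecutor

def _score(args):
    # Runs in the worker process's main thread, where signal.alarm() is allowed
    reward_fn, prompt, completion = args
    return reward_fn(prompt, completion)

def parallel_rewards(reward_fn, prompts, completions, num_workers=4):
    """Score completions in worker subprocesses (hypothetical helper, not axolotl's API)."""
    if num_workers <= 1:
        return [reward_fn(p, c) for p, c in zip(prompts, completions)]
    jobs = [(reward_fn, p, c) for p, c in zip(prompts, completions)]
    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(_score, jobs))
```

A complete end-to-end example configuration follows.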
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  gpu_memory_utilization: 0.35
  dtype: auto

adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true

rl: grpo
trl:
  use_data_producer: true
  use_vllm: true
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  vllm_lora_sync: true
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  off_policy_mask_threshold: 0.5
  importance_sampling_level: token
  num_generations: 8
  max_completion_length: 512
  reward_funcs:
    - rewards.accuracy_reward
  reroll_start_fraction: 0.5
  replay_buffer_size: 100
  reward_num_workers: 4
  skip_zero_advantage_batches: true

datasets:
  - path: AI-MO/NuminaMath-TIR
    type: rewards.prompt_transform
    split: train

gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
FSDP:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
  use_reentrant: false
```

DeepSpeed ZeRO-3:
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
  use_reentrant: true  # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```

With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
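A minimal sketch of the rank-0-generates-then-broadcast pattern, assuming torch.distributed is already initialized (illustrative only, not axolotl's implementation):

```python
import torch.distributed as dist

def prefetch_rollouts(generate_fn, prompts):
    """Rank 0 generates; every rank receives the same rollouts (illustrative sketch)."""
    payload = [generate_fn(prompts) if dist.get_rank() == 0 else None]
    # The collective runs on the main thread of every rank, so FSDP/DeepSpeed
    # collectives never race against an unsynchronized background thread.
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```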
Paper: https://arxiv.org/pdf/2501.05242
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
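The difference between the two aggregation orders can be sketched as follows; this is a simplified illustration of group-relative advantages with multiple reward functions, not TRL's exact code:

```python
import torch

def group_advantages(rewards, weights, mode="normalize_then_sum", eps=1e-4):
    """Aggregate multi-objective rewards into advantages for one prompt group.

    rewards: float tensor [num_objectives, num_generations]; illustrative only.
    """
    w = torch.tensor(weights, dtype=rewards.dtype).unsqueeze(-1)
    if mode == "normalize_then_sum":
        # GDPO: whiten each objective within the group, then take the weighted sum
        norm = (rewards - rewards.mean(dim=-1, keepdim=True)) / (rewards.std(dim=-1, keepdim=True) + eps)
        return (w * norm).sum(dim=0)
    # Default GRPO: take the weighted sum first, then whiten the combined reward
    total = (w * rewards).sum(dim=0)
    return (total - total.mean()) / (total.std() + eps)
```

The complete GDPO example configuration follows.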
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85

rl: gdpo

trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]

datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```

You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum  # GDPO behavior
  # or: sum_then_normalize                         # Default GRPO behavior
```