--- title: "RLHF (Beta)" description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback." back-to-top-navigation: true toc: true toc-expand: 2 toc-depth: 4 --- ## Overview Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Various methods include, but not limited to: - [Direct Preference Optimization (DPO)](#dpo) - [Identity Preference Optimization (IPO)](#ipo) - [Kahneman-Tversky Optimization (KTO)](#kto) - [Odds Ratio Preference Optimization (ORPO)](#orpo) - [Group Relative Policy Optimization (GRPO)](#grpo) — see also the [GRPO deep dive](grpo.qmd) for async features, custom rewards, and scaling - [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo) - [Energy-Based Fine-Tuning (EBFT)](#ebft) — see also the [EBFT guide](ebft.qmd) for detailed mode comparisons and configuration - [NeMo Gym Integration](#nemo-gym-integration) For help choosing between these methods, see [Choosing a Fine-Tuning Method](choosing_method.qmd). ## RLHF using Axolotl ::: {.callout-important} This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. ::: We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap around to expose in axolotl. Each method has their own supported ways of loading datasets and prompt formats. ::: {.callout-tip} You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`. ::: ### DPO Example config: ```yaml rl: dpo datasets: - path: Intel/orca_dpo_pairs split: train type: chatml.intel - path: argilla/ultrafeedback-binarized-preferences split: train type: chatml ``` DPO supports the following types with the following dataset format: #### chatml.argilla ```json { "system": "...", // optional "instruction": "...", "chosen_response": "...", "rejected_response": "..." } ``` #### chatml.argilla_chat ```json { "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` #### chatml.icr ```json { "system": "...", // optional "input": "...", "chosen": "...", "rejected": "..." } ``` #### chatml.intel ```json { "system": "...", // optional "question": "...", "chosen": "...", "rejected": "..." } ``` #### chatml.prompt_pairs ```json { "system": "...", // optional "prompt": "...", "chosen": "...", "rejected": "..." } ``` #### chatml.ultra ```json { "system": "...", // optional "prompt": "...", "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` #### llama3.argilla ```json { "system": "...", // optional "instruction": "...", "chosen_response": "...", "rejected_response": "..." } ``` #### llama3.argilla_chat ```json { "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` #### llama3.icr ```json { "system": "...", // optional "input": "...", "chosen": "...", "rejected": "..." 
} ``` #### llama3.intel ```json { "system": "...", // optional "question": "...", "chosen": "...", "rejected": "..." } ``` #### llama3.prompt_pairs ```json { "system": "...", // optional "prompt": "...", "chosen": "...", "rejected": "..." } ``` #### llama3.ultra ```json { "system": "...", // optional "prompt": "...", "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` #### zephyr.nectar ```json { "prompt": "...", "answers": [ { "answer": "...", "rank": 1 }, { "answer": "...", "rank": 2 } // ... more answers with ranks ] } ``` #### chat_template.argilla_chat ```json { "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` #### chat_template.default ```yaml rl: dpo datasets: - path: ... split: train type: chat_template.default field_messages: "messages" field_chosen: "chosen" field_rejected: "rejected" message_property_mappings: role: role content: content roles: user: ["user"] assistant: ["assistant"] system: ["system"] ``` Sample input format: ```json { "messages": [ { "role": "system", "content": "..." }, { "role": "user", "content": "..." }, // ... more messages ], "chosen": { "role": "assistant", "content": "..." }, "rejected": { "role": "assistant", "content": "..." } } ``` #### user_defined.default For custom behaviors, ```yaml rl: dpo datasets: - path: ... split: train type: field_prompt: "prompt" field_system: "system" field_chosen: "chosen" field_rejected: "rejected" prompt_format: "{prompt}" chosen_format: "{chosen}" rejected_format: "{rejected}" ``` The input format is a simple JSON input with customizable fields based on the above config. ```json { "system": "...", // optional "prompt": "...", "chosen": "...", "rejected": "..." } ``` ### IPO As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO. ```yaml rl: dpo dpo_loss_type: ["ipo"] ``` *Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated. ### ORPO Paper: https://arxiv.org/abs/2403.07691 ```yaml rl: orpo orpo_alpha: 0.1 remove_unused_columns: false chat_template: chatml datasets: - path: argilla/ultrafeedback-binarized-preferences-cleaned type: chat_template.argilla ``` ORPO supports the following types with the following dataset format: #### chat_template.argilla ```json { "system": "...", // optional "prompt": "...", // if available, will be taken as user message for single-turn instead of from list below // chosen/rejected should be same till last content and only even-number of alternating user/assistant turns "chosen": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ], "rejected": [ {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."} ] } ``` ### KTO ```yaml rl: kto rl_beta: 0.1 # default kto_desirable_weight: 1.0 # default kto_undesirable_weight: 1.0 # default remove_unused_columns: false datasets: - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto type: llama3.ultra split: train gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true ``` KTO supports the following types with the following dataset format: #### chatml.argilla ```json { "system": "...", // optional "instruction": "...", "completion": "..." 
}
```

#### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

#### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

#### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

#### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

#### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

#### user_defined.default

For custom behaviors,

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type:
      field_prompt: "prompt"
      field_system: "system"
      field_completion: "completion"
      field_label: "label"
      prompt_format: "{prompt}"
      completion_format: "{completion}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```

### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).

For a comprehensive guide covering async training, custom rewards, importance sampling, and scaling, see the [GRPO deep dive](grpo.qmd).
:::

In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:

::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85
  dtype: auto
  # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand

rl: grpo
trl:
  use_vllm: true
  vllm_server_host: 0.0.0.0
  vllm_server_port: 8000
  vllm_server_timeout: 300
```

```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```

Your `vLLM` instance will now attempt to spin up, and it's time to kick off training using the remaining two GPUs. In another terminal, execute:

```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```

::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::

#### Reward functions

GRPO uses custom reward functions and transformations. Please have them ready locally.
For example, to load OpenAI's GSM8K and use a random reward for completions:

```python
# rewards.py
import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
```

```yaml
rl: grpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: True
  num_generations: 4
  reward_funcs: ["rewards.rand_reward_func"]  # format: '{file_name}.{fn_name}'
  reward_weights: [1.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform  # format: '{file_name}.{fn_name}'
```

To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).

To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).

#### OpenEnv Rollout Functions

GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.

For example, to implement a simple math-solving environment with step-by-step verification:

```python
# math_env.py
import re


def math_solver_rollout(model, processing_class, prompts, generation_config=None):
    """
    Custom rollout function that generates step-by-step math solutions.

    Args:
        model: The language model
        processing_class: The tokenizer/processing class
        prompts: List of prompt dicts (with 'messages' key for chat format)
        generation_config: Optional generation configuration

    Returns:
        List of completion strings
    """
    completions = []

    for prompt in prompts:
        # Apply chat template to prompt
        messages = prompt.get("messages", [])
        formatted_prompt = processing_class.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate step-by-step solution
        full_response = ""
        for step in range(5):  # Max 5 reasoning steps
            current_input = formatted_prompt + full_response + "\nNext step:"
            inputs = processing_class(current_input, return_tensors="pt").to(model.device)

            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                generation_config=generation_config,
            )
            step_text = processing_class.decode(
                outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
            )

            # Check if solution is complete
            if "FINAL ANSWER:" in step_text:
                full_response += step_text
                break
            full_response += step_text + "\n"

        completions.append(full_response)

    return completions


def math_reward(prompts, completions, answers, **kwargs):
    """Reward function that checks mathematical correctness"""
    rewards = []
    for completion, correct_answer in zip(completions, answers):
        # Extract predicted answer
        match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
        predicted = match.group(1).strip() if match else ""

        # Compare with correct answer
        reward = 1.0 if predicted == str(correct_answer) else 0.0
        rewards.append(reward)
    return rewards


def math_transform(cfg, *args, **kwargs):
    """Transform dataset to GRPO format with answer field"""
    def transform_fn(example, processing_class=None):
        return {
            "prompt": [{"role": "user", "content": example["question"]}],
            "answer": str(example["answer"]),
        }

    return transform_fn, {"remove_columns": ["question"]}
```
{"remove_columns": ["question"]} ``` ```yaml rl: grpo trl: beta: 0.001 max_completion_length: 512 num_generations: 4 rollout_func: "math_env.math_solver_rollout" # Custom rollout function reward_funcs: ["math_env.math_reward"] reward_weights: [1.0] datasets: - path: openai/gsm8k name: main type: math_env.math_transform ``` The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives: - `model`: The language model - `processing_class`: The tokenizer/processing class - `prompts`: List of prompt dictionaries - `generation_config` (optional): Generation configuration And should return a list of completion strings. For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv). #### GRPO with DAPO/Dr. GRPO loss The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses. ```yaml trl: loss_type: dr_grpo # Normalizes loss based on max completion length (default: 256) max_completion_length: ``` For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types). #### Async GRPO Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step. ```yaml trl: use_data_producer: true # Enable data producer protocol use_vllm: true async_prefetch: true # Generate rollouts in background thread prefetch_depth: 1 # Number of rollouts to prefetch vllm_sync_interval: 2 # Sync weights to vLLM every N steps ``` ::: {.callout-note} Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled). ::: ##### vLLM LoRA Sync By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels. ```yaml adapter: lora lora_r: 32 lora_alpha: 64 lora_target_linear: true trl: vllm_lora_sync: true # Enable native LoRA sync ``` When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual: ```bash CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml ``` Then start training on a separate GPU: ```bash CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml ``` ::: {.callout-tip} LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation. ::: ##### Streaming Partial Batch Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring. ```yaml trl: streaming_partial_batch: true ``` ##### Importance Sampling Correction When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift. 
```yaml trl: vllm_importance_sampling_correction: true # Enable IS correction importance_sampling_level: token # 'token' or 'sequence' off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this ``` - `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel) - `importance_sampling_level: sequence` applies per-sequence IS ratios - `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy ##### Replay Buffer The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches. ```yaml trl: replay_buffer_size: 100 # Max cached groups (0 = disabled) replay_recompute_logps: true # Recompute log-probs for replayed data (recommended) ``` ::: {.callout-note} When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data. ::: ##### Deferred Re-rolling Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them. ```yaml trl: reroll_start_fraction: 0.5 # Start re-rolling after 50% of training reroll_max_groups: 1 # Max groups to replace per batch ``` ##### Zero-Advantage Batch Skipping When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`. ```yaml trl: skip_zero_advantage_batches: true # default ``` ##### Parallel Reward Workers Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation. ```yaml trl: reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism) ``` ##### Full Async GRPO Example ```yaml base_model: Qwen/Qwen2.5-1.5B-Instruct vllm: host: 0.0.0.0 port: 8000 gpu_memory_utilization: 0.35 dtype: auto adapter: lora lora_r: 32 lora_alpha: 64 lora_target_linear: true rl: grpo trl: use_data_producer: true use_vllm: true async_prefetch: true prefetch_depth: 1 vllm_sync_interval: 2 vllm_lora_sync: true streaming_partial_batch: true vllm_importance_sampling_correction: true off_policy_mask_threshold: 0.5 importance_sampling_level: token num_generations: 8 max_completion_length: 512 reward_funcs: - rewards.accuracy_reward reroll_start_fraction: 0.5 replay_buffer_size: 100 reward_num_workers: 4 skip_zero_advantage_batches: true datasets: - path: AI-MO/NuminaMath-TIR type: rewards.prompt_transform split: train gradient_accumulation_steps: 4 micro_batch_size: 2 max_steps: 500 learning_rate: 1e-5 bf16: true gradient_checkpointing: true ``` ```bash # Terminal 1: Start vLLM on GPU 0 CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml # Terminal 2: Train on GPU 1 CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml ``` ##### Multi-GPU Async GRPO Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs. 
**FSDP:**

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer

gradient_checkpointing_kwargs:
  use_reentrant: false
```

**DeepSpeed ZeRO-3:**

```yaml
deepspeed: deepspeed_configs/zero3_bf16.json

gradient_checkpointing_kwargs:
  use_reentrant: true  # Required for ZeRO-3
```

```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml

# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
```

::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::

### GDPO

GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

::: {.callout-tip}
Use GDPO when training with multiple reward functions. For a single reward, GRPO and GDPO produce equivalent results.
:::

Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)

GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.

```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct

vllm:
  host: 0.0.0.0
  port: 8000
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.85

rl: gdpo
trl:
  beta: 0.001
  max_completion_length: 256
  use_vllm: true
  num_generations: 4
  reward_funcs:
    - rewards.format_reward
    - rewards.correctness_reward
  reward_weights: [1.0, 2.0]
datasets:
  - path: openai/gsm8k
    name: main
    type: rewards.oai_gsm8k_transform
```

You can also use GRPO with explicit aggregation control:

```yaml
rl: grpo
trl:
  multi_objective_aggregation: normalize_then_sum  # GDPO behavior
  # or: sum_then_normalize  # Default GRPO behavior
```

#### GDPO vs GRPO

| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |

#### Why GDPO?

When using multiple rewards with GRPO, different reward combinations can produce identical advantages:

```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3  ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```

GDPO normalizes each reward independently, preserving their relative differences.

#### Reward Functions

GDPO uses the same reward function format as GRPO:

```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
    return [1.0 if len(c) > 10 else 0.0 for c in completions]


def correctness_reward(completions, answers, **kwargs) -> list[float]:
    rewards = []
    for completion, answer in zip(completions, answers):
        # Your scoring logic here, e.g. a simple containment check:
        score = 1.0 if str(answer) in completion else 0.0
        rewards.append(score)
    return rewards
```

#### Sequence Parallelism

GDPO supports sequence parallelism for long-context training:

```yaml
rl: gdpo
context_parallel_size: 2
```

### SimPO

SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1  # default in CPOTrainer
cpo_alpha: 1.0  # default in CPOTrainer
simpo_gamma: 0.5  # default in CPOTrainer
```

This method uses the same dataset format as [DPO](#dpo).
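For intuition, SimPO drops the reference model entirely: the implicit reward is the length-normalized average log-probability of a response, and the pairwise loss enforces a fixed target margin `simpo_gamma` between chosen and rejected responses. A minimal sketch of the per-pair loss, for illustration only (the actual implementation lives in TRL's CPOTrainer):

```python
import torch.nn.functional as F

def simpo_pair_loss(chosen_logps, rejected_logps, chosen_len, rejected_len,
                    beta=0.1, gamma=0.5):
    """chosen_logps / rejected_logps: summed token log-probs of each response (tensors)."""
    r_chosen = beta * chosen_logps / chosen_len        # length-normalized implicit reward
    r_rejected = beta * rejected_logps / rejected_len
    return -F.logsigmoid(r_chosen - r_rejected - gamma)  # sigmoid loss with target margin
```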
### EBFT {#ebft} ::: {.callout-tip} For a detailed guide on EBFT modes, feature extraction, and configuration, see the [EBFT guide](ebft.qmd). ::: EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments. Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026) **Key advantages:** - No reward model or verifier required — works on any (prompt, completion) data - Applicable to non-verifiable tasks (code, translation, creative writing) - Operates on model rollouts (not teacher forcing), reducing distribution shift EBFT supports two modes: - **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO). - **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed. #### Structured Mode ```yaml base_model: Qwen/Qwen3-4B rl: ebft ebft: feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth embed_method: last_token use_whitening: false alignment_coef: 1.0 # Cosine similarity reward weight diversity_coef: 1.0 # Pairwise dot product penalty ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off) trl: num_generations: 4 max_completion_length: 256 temperature: 0.7 use_vllm: true vllm_server_host: 0.0.0.0 vllm_server_port: 8000 vllm_lora_sync: true # LoRA adapter sync (recommended) vllm_sync_interval: 3 use_data_producer: true async_prefetch: true # Set false for sync mode scale_rewards: true loss_type: grpo epsilon: 0.2 vllm: gpu_memory_utilization: 0.5 max_model_len: 2048 datasets: - path: nvidia/OpenCodeInstruct type: ebft_opencode.transform split: train[:500] adapter: lora lora_r: 16 lora_alpha: 32 lora_target_linear: true ``` ```bash # Terminal 1: Start vLLM CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml # Terminal 2: Train CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml ``` #### Strided Mode For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU. ```yaml base_model: meta-llama/Llama-3.2-1B rl: ebft ebft: mode: strided stride: 8 context_length: 8 generate_max_len: 8 n_samples_per_prompt: 4 temperature: 0.6 feature_layers: [0.25, 0.5, 0.75] embed_method: last_token use_whitening: true alignment_coef: 1.0 diversity_coef: 1.0 rl_coef: 1.0 ce_coef: 0.03 advantage_estimator: rloo datasets: - path: nvidia/OpenCodeInstruct type: ebft_strided_structured.transform split: train[:1%] flash_attention: false flex_attention: true # Strided mode uses flex_attention gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true # Required for flex_attention ``` ```bash CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml ``` ::: {.callout-tip} See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes. 
::: #### EBFT Configuration Reference | Parameter | Default | Description | |-----------|---------|-------------| | `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) | | `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` | | `ebft.use_whitening` | `false` | SVD whitening of feature dimensions | | `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight | | `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight | | `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens | | `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) | | `ebft.stride` | — | Tokens between anchor points (strided mode) | | `ebft.context_length` | — | Context window per block (strided mode) | | `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) | | `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) | | `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) | ### NeMo Gym Integration [NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`). #### Single-Turn (Simplest) For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server. ```yaml base_model: Qwen/Qwen2.5-0.5B-Instruct rl: grpo chat_template: tokenizer_default trl: use_vllm: false # Colocate mode (single GPU) num_generations: 4 max_completion_length: 128 temperature: 0.9 reward_funcs: - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify plugins: - axolotl.integrations.nemo_gym.NemoGymPlugin nemo_gym_enabled: true nemo_gym_dir: ~/Gym nemo_gym_auto_start: false nemo_gym_head_port: 11000 nemo_gym_datasets: - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl server_name: reasoning_gym datasets: - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl type: chat_template field_messages: responses_create_params.input message_field_content: content message_field_role: role ``` ```bash # Terminal 1: Start NeMo Gym resource server cd ~/Gym && .venv/bin/ng_run \ "+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \ "+skip_venv_if_present=true" # Terminal 2: Train CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml ``` ::: {.callout-note} `nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined. ::: #### Multi-Turn with Async GRPO (Recommended) For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done. 
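Schematically, a single multi-turn rollout is a loop like the one below. This is only a sketch of the control flow; the real orchestration happens inside the NeMo Gym agent server, and the `generate`, `parse_tool_calls`, and `execute_tool` callables are placeholders, not actual APIs:

```python
def multi_turn_rollout(messages, generate, parse_tool_calls, execute_tool, max_turns=8):
    """Placeholder loop: generate -> parse tool calls -> execute tools -> feed results back."""
    transcript = list(messages)
    for _ in range(max_turns):
        reply = generate(transcript)                      # model server (vLLM) turn
        transcript.append({"role": "assistant", "content": reply})
        tool_calls = parse_tool_calls(reply)              # e.g. hermes-style parsing
        if not tool_calls:                                # no tool requested -> done
            break
        for call in tool_calls:
            result = execute_tool(call)                   # resource server runs the tool
            transcript.append({"role": "tool", "content": result})
    return transcript                                     # trajectory is then scored via /run
```

An end-to-end config for this setup looks like: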
```yaml base_model: Qwen/Qwen3-0.6B rl: grpo chat_template: tokenizer_default adapter: lora lora_r: 16 lora_alpha: 32 lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] trl: use_vllm: true vllm_mode: server vllm_server_host: localhost vllm_server_port: 8000 vllm_lora_sync: true vllm_sync_interval: 5 use_data_producer: true async_prefetch: true # 3x speedup num_generations: 4 max_completion_length: 512 temperature: 0.8 reward_funcs: - axolotl.integrations.nemo_gym.rewards.reward_env plugins: - axolotl.integrations.nemo_gym.NemoGymPlugin nemo_gym_enabled: true nemo_gym_auto_start: false nemo_gym_head_port: 11000 nemo_gym_multi_turn: true nemo_gym_verify_timeout: 120 nemo_gym_datasets: - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl server_name: example_single_tool_call datasets: - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl type: chat_template field_messages: responses_create_params.input message_field_content: content message_field_role: role vllm: gpu_memory_utilization: 0.85 max_model_len: 2048 ``` Multi-turn requires three services running: ```bash # Terminal 1: vLLM with LoRA + tool calling VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \ python -m vllm.entrypoints.openai.api_server \ --model Qwen/Qwen3-0.6B --max-model-len 2048 \ --gpu-memory-utilization 0.85 \ --enable-lora --max-lora-rank 64 \ --enable-auto-tool-choice --tool-call-parser hermes # Terminal 2: NeMo Gym servers (resource + model proxy + agent) cd ~/Gym && .venv/bin/ng_run \ "+config_paths=[configs/axolotl_tool_calling.yaml]" \ "+skip_venv_if_present=true" # Terminal 3: Training CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml ``` ::: {.callout-important} Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format. 
:::

#### NeMo Gym Prerequisites

```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync

# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```

#### NeMo Gym Configuration Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |

#### Reward Functions

| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |

### Using local dataset files

```yaml
datasets:
  - ds_type: json
    data_files:
      - orca_rlhf.jsonl
    split: train
    type: chatml.intel
```

### TRL auto-unwrapping for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:

```yaml
# load ref model when adapter training.
rl_adapter_ref_model: true
```
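Conceptually, the unwrapping works because a PEFT model with its adapter disabled is just the frozen base model, so reference log-probabilities can be computed from the same weights without loading a second copy. A rough sketch using PEFT's `disable_adapter()` context manager (illustrative only, not TRL's internal code):

```python
import torch

def reference_logprobs(peft_model, input_ids, attention_mask):
    """Per-token log-probs of input_ids under the base model (adapter disabled)."""
    with torch.no_grad(), peft_model.disable_adapter():
        logits = peft_model(input_ids=input_ids, attention_mask=attention_mask).logits
    logprobs = torch.log_softmax(logits[:, :-1], dim=-1)
    # gather the log-prob assigned to each actual next token
    return torch.gather(logprobs, 2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
```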