- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1].
Triton forward/backward take an n_rot runtime arg that restricts
rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only
pass-through (cos=1, sin=0 defaults). Wrapper also expands cos/sin
that broadcast over batch.
- Forward: _make_fused_forward used a stale shared_kv_states kwarg the
current decoder layer no longer passes. Now mirrors stock attention,
reading/writing past_key_values.shared_layers.
The LoRA vllm-serve wrapper only exposed /v1/chat/completions, but
retrace's SWE agent server uses the token-id-aware /v1/completions
endpoint so it can feed raw prompt_token_ids + track per-token
logprobs across multi-turn rollouts. Add the route, mirroring the
shape of /v1/chat/completions but routing to the vLLM worker's
generate() method so prompt_token_ids are passed through as-is.
Also add a worker_pipe_lock around conn.send/conn.recv. The
multiprocessing.Connection to the vLLM worker is a single shared
full-duplex pipe; concurrent HTTP requests interleave pickle frames
on the wire and corrupt the stream (observed as
UnpicklingError: pickle data was truncated, surfacing as 500s).
The agent server fires ~8 concurrent rollout requests at once, so
this was a hard blocker for any multi-concurrent workload. Serialize
access to the pipe per-request round-trip.
Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:
1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
at the top. The original design assumed sync mode would fall back
to TRL's stock per-step ``sync_weights`` call inside
``_generate_single_turn`` — true for vanilla GRPO but FALSE in
NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
server directly and ``_generate_single_turn`` is never invoked.
Result: no sync ever happened in NeMo Gym sync mode.
2. ``step % vllm_sync_interval`` would TypeError on the first call if
``vllm_sync_interval`` was unset (the default for any config that
doesn't explicitly set it).
3. The ``_generate_single_turn`` patch installed
``vllm_generation.sync_weights = lambda: None`` unconditionally
for vllm_lora_sync runs. That's correct in async-prefetch mode
(BG thread can't safely sync) but wrong in sync mode: TRL's
per-step auto-sync inside ``_generate_single_turn`` was the
fallback that the early return in (1) was assuming, and the
no-op patch was killing it.
Fix:
- Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
is now the canonical sync trigger and runs in both modes from
``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
- Default ``vllm_sync_interval`` to 1 when unset.
- In the ``_generate_single_turn`` patch, route sync_weights to
``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
in async mode for the BG-thread safety reason).
``get_server_configs`` was hardcoded to a 5s timeout with no retry.
That's empirically too tight to survive a kill-and-relaunch cycle:
when the agent server is finishing in-flight rollouts from a prior
run, it can take 10-30s to respond to /global_config_dict_yaml, and
the trainer would crash at startup with a ReadTimeoutError.
Bump the per-attempt timeout to 30s and retry up to 3 times with a
2s/4s backoff. The retry intentionally raises a RuntimeError after
the third failure rather than returning empty config — silent
failure here would let training proceed with no agent servers
discovered, which is also a no-op trainer.
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).
Replace both patches with:
1. A probe of the configured vLLM server's /openapi.json at
pre_model_load. Three transports are recognized:
- NCCL (/init_communicator/ + /update_named_param/) — TRL serve
and axolotl vllm-serve both expose this
- LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
- HTTP base64 full-weight (/http_update_weights/) — axolotl
vllm-serve only
2. A pure-logic ``select_weight_sync_transport`` that picks the
right one for (server caps × adapter type).
3. ``init_communicator`` is only patched out when the server has no
NCCL routes; against TRL/axolotl serve modules it stays live so
full-finetune NCCL sync works.
4. ``post_trainer_create`` uses the selection table to install LoRA
filesystem sync OR leave the standard NCCL flow alone OR raise
NotImplementedError (HTTP — pending) OR raise a precise diagnosis
when no transport is viable. No more silent no-op trainers.
Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:
- effective generation_batch_size (or mb*GA fallback) must be divisible
by trl.num_generations, with a hint pointing at the smallest GA bump
that fixes the violation
- num_generations >= 2 (group-relative advantage needs variance; with
num_gen=1 the policy never updates)
- When world_size > 1, effective gbs >= num_generations * world_size
11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
* bump transformers to 5.5.4 and trl to latest 1.1.0
* more upgrades
* update peft too
* adapt lora_merge to peft 0.19 layer config API
PEFT 0.19 requires a LoraConfig object on Linear/ParamWrapper/Conv
layer constructors and moved use_rslora, use_dora, fan_in_fan_out,
lora_dropout, and lora_bias into that config. Build the config
per branch in _build_peft_layer_and_get_delta so the merge utility
works with the upgraded peft.
* allow lora_dropout on mixed attention+MoE configs under peft 0.19
PEFT 0.19's convert_peft_config_for_transformers auto-remaps old MoE
target_modules (w1/w2/w3 on Mixtral, etc.) into target_parameters for
transformers v5's fused 3D expert Parameters. Those targets get wrapped
with ParamWrapper, which rejects lora_dropout != 0 because the 3D
einsum can't factor dropout out of lora_B(lora_A(dropout(x))).
Monkeypatch ParamWrapper.__init__ to internally use a copy of the
LoraConfig with lora_dropout=0, so its dropout slot becomes nn.Identity
while the shared config still delivers real dropout to sibling Linear
LoRA layers (attention q/k/v/o). A probe runs the same conversion on a
deep copy to detect the situation and emit a warning before patching.
* fix: rename model to adapter_model for fsdp sharded final model
* fix: follow upstream transformer shard size
* fix: handle multiple model files
* fix redundant condition, tighten to safetensors, keep shard size small
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* feat: support excess_length_strategy for RL trainers
Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.
- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
within sequence_len while preserving the full prompt, then decodes
back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len
Closes#3547
* improve RL truncation strategy robustness and performance
---------
Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with
FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning.
Made-with: Cursor
* Skip redundant evaluation when resuming from checkpoint
* add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* better handling of dora merge on Conv layers in Qwen 3.5
* address issues from code review
* stricter efficient merges for dora since we now have meta model to reference
* qwen3_5.jinja: handle list content on system messages
The system message branch used string concatenation on
messages[0].content, which breaks when the first system message uses
the OpenAI-style list-of-parts format that multimodal datasets require.
User and assistant branches already handle both string and list content,
but the system branch did not.
Check whether content is a string and fall back to iterating over parts
when it is a list, matching the pattern used for user messages.
Fixes#3590
* Address pr for other content types
---------
Co-authored-by: Joaquin Hui Gomez <joaquinhuigomez@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: remove unneeded debug log
* fix: cleanup
* feat: add dense gemma config and cleanup
* feat: add cce support
* update notes and set torch compile
* fix patch for new number of return vals
* fixes for gemma4
* fix packing bug
* use updated cce for mm
* fix: pass in kv cache func when avail for transformers 5.5
* feat: update examples with flex variant and readme
* gemma4 lora attention kernels
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* docs: comprehensive documentation improvements for humans and agents
New human docs:
- grpo.qmd: GRPO deep dive (async, rewards, IS correction, scaling)
- ebft.qmd: EBFT guide (structured/strided modes, feature extraction)
- choosing_method.qmd: decision tree for SFT vs LoRA vs DPO vs GRPO
- vllm_serving.qmd: vLLM setup for GRPO (server/colocate, LoRA sync)
- training_stability.qmd: monitoring, NaN debugging, OOM, healthy metrics
New agent docs:
- AGENTS_SFT.md: agent reference for supervised fine-tuning
- AGENTS_DPO.md: agent reference for preference learning (DPO/KTO/ORPO)
Updated existing docs:
- rlhf.qmd: cross-references to new GRPO/EBFT/choosing-method guides
- getting-started.qmd: reorganized Next Steps with links to new guides
- debugging.qmd: link to training stability guide
- _quarto.yml: added new pages to sidebar navigation
Removed:
- bak.agents.md: stale backup that confused agents
* docs: trim duplicated generic config from AGENTS_DPO.md
Remove boilerplate training params (optimizer, gradient_checkpointing,
flash_attention, etc.) from each method template. These are not
preference-learning-specific and are already covered in AGENTS_SFT.md.
Config templates now show only method-specific fields with a reference
to AGENTS_SFT.md for the rest.
* docs: deduplicate across new doc pages
- grpo.qmd: collapse vLLM setup section to brief config + link to
vllm_serving.qmd; collapse IS correction to essentials + link;
replace full monitoring tables with summary + link to
training_stability.qmd
- vllm_serving.qmd: remove duplicated async/IS config reference tables
(already in grpo.qmd config reference); replace full example config
with link to grpo.qmd quick start
- ebft.qmd: trim generic training params in quick start config
* fix: train scripts
* feat: split files into cleaner parts
* fix: cleanup pretraining docs
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]
- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
argilla_chat() so datasets with tool-call messages no longer raise
KeyError: 'tool' (closes#3217)
- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
overrides content instead of tokenizer.name_or_path when added_tokens_overrides
is set, preventing cache busting when only output_dir changes (closes#3303)
- Add three Pydantic config validators to AxolotlConfigWCapabilities:
* save_strategy: 'best' requires metric_for_best_model
* streaming=True is incompatible with val_set_size > 0
* lora_target_modules list entries must be valid Python regex patterns
- Tests for all three changes
* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash
* chore: lint
* move the validators out of the w/ capabilities schema
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* qat patch
* tests fixes
* fixup per PR code review
* use state dict hooks to handle dequant for saving safetensors from transformers
* use transformers torch ao quantizer hooks to save mx quantized model
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* Add precompute_ref_log_probs to config schema
* chore: add description for config
* Add test for precompute_ref_log_probs and move to training args
* useing precompute logprobs as the default slows down CI as it has to precompute
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* allow bf16 flag but warn
Reason: when doing e.g. LoRA merges with CUDA_VISIBLE_DEVICES=, this will unnecessarily crash, even though the LoRA merge operation would have finished successfully. This seems to warrant changing it to a warning instead, as the code will most likely crash later if bf16 is unavailable and training begins anyway.
* don't use deprecated LOG.warn
* update tests to reflect validation change