Two bugs in ``AsyncGRPOTrainer._maybe_sync_vllm_weights`` plus a
companion bug in the sync-hook patch site that together neutralized
LoRA weight sync entirely whenever ``async_prefetch=False`` was
combined with NeMo Gym's data-producer path:
1. ``_maybe_sync_vllm_weights`` had ``if not async_prefetch: return``
at the top. The original design assumed sync mode would fall back
to TRL's stock per-step ``sync_weights`` call inside
``_generate_single_turn`` — true for vanilla GRPO but FALSE in
NeMo Gym multi-turn, where ``NemoGymDataProducer`` calls the agent
server directly and ``_generate_single_turn`` is never invoked.
Result: no sync ever happened in NeMo Gym sync mode.
2. ``step % vllm_sync_interval`` would TypeError on the first call if
``vllm_sync_interval`` was unset (the default for any config that
doesn't explicitly set it).
3. The ``_generate_single_turn`` patch installed
``vllm_generation.sync_weights = lambda: None`` unconditionally
for vllm_lora_sync runs. That's correct in async-prefetch mode
(BG thread can't safely sync) but wrong in sync mode: TRL's
per-step auto-sync inside ``_generate_single_turn`` was the
fallback that the early return in (1) was assuming, and the
no-op patch was killing it.
Fix:
- Drop the ``not async_prefetch`` early return; ``_maybe_sync_vllm_weights``
is now the canonical sync trigger and runs in both modes from
``_prepare_inputs_with_data_producer`` / ``_prepare_inputs_legacy_async``.
- Default ``vllm_sync_interval`` to 1 when unset.
- In the ``_generate_single_turn`` patch, route sync_weights to
``_sync_lora_adapter`` in sync mode (and keep the lambda no-op
in async mode for the BG-thread safety reason).
The plugin used to unconditionally monkey-patch
VLLMClient.init_communicator to a no-op AND silently no-op
sync_weights when vllm_lora_sync was off. Combined, this turned the
trainer into a functional no-op whenever (a) the user ran NeMo Gym
+ LoRA without remembering to set vllm_lora_sync=true or (b) the
user ran NeMo Gym + full fine-tune (which had no working sync path
under the old code).
Replace both patches with:
1. A probe of the configured vLLM server's /openapi.json at
pre_model_load. Three transports are recognized:
- NCCL (/init_communicator/ + /update_named_param/) — TRL serve
and axolotl vllm-serve both expose this
- LoRA filesystem (/v1/load_lora_adapter or /set_lora_adapter/)
- HTTP base64 full-weight (/http_update_weights/) — axolotl
vllm-serve only
2. A pure-logic ``select_weight_sync_transport`` that picks the
right one for (server caps × adapter type).
3. ``init_communicator`` is only patched out when the server has no
NCCL routes; against TRL/axolotl serve modules it stays live so
full-finetune NCCL sync works.
4. ``post_trainer_create`` uses the selection table to install LoRA
filesystem sync OR leave the standard NCCL flow alone OR raise
NotImplementedError (HTTP — pending) OR raise a precise diagnosis
when no transport is viable. No more silent no-op trainers.
Surfaces a class of GRPO config errors at axolotl-train startup instead
of letting them bubble out of GRPOTrainer.__init__ after the model loads.
Three checks under RLValidationMixin.check_grpo_batch_size_divisibility:
- effective generation_batch_size (or mb*GA fallback) must be divisible
by trl.num_generations, with a hint pointing at the smallest GA bump
that fixes the violation
- num_generations >= 2 (group-relative advantage needs variance; with
num_gen=1 the policy never updates)
- When world_size > 1, effective gbs >= num_generations * world_size
11 unit tests cover the table: divisible/non-divisible, explicit and
implicit gbs, multi-rank constraint, GRPO-disabled passthrough, and
unset num_generations.
* feat: support excess_length_strategy for RL trainers
Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.
- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
within sequence_len while preserving the full prompt, then decodes
back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len
Closes#3547
* improve RL truncation strategy robustness and performance
---------
Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* Skip redundant evaluation when resuming from checkpoint
* add condition check for adding callback
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* better handling of dora merge on Conv layers in Qwen 3.5
* address issues from code review
* stricter efficient merges for dora since we now have meta model to reference
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
* feat: add sonicmoe fused lora support
* fix: forgot to add file
* feat: add test
* feat: add lora support for other routes
* fix: add int8 lora support
* fix: add qwen35_moe interleave support
* fix: qwen3_5_moe loss
* chore: lint
* address some pr comments
* fix test imports
* add support matrix for moe kernels [skip ci]
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]
- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
argilla_chat() so datasets with tool-call messages no longer raise
KeyError: 'tool' (closes#3217)
- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
overrides content instead of tokenizer.name_or_path when added_tokens_overrides
is set, preventing cache busting when only output_dir changes (closes#3303)
- Add three Pydantic config validators to AxolotlConfigWCapabilities:
* save_strategy: 'best' requires metric_for_best_model
* streaming=True is incompatible with val_set_size > 0
* lora_target_modules list entries must be valid Python regex patterns
- Tests for all three changes
* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash
* chore: lint
* move the validators out of the w/ capabilities schema
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
* qat patch
* tests fixes
* fixup per PR code review
* use state dict hooks to handle dequant for saving safetensors from transformers
* use transformers torch ao quantizer hooks to save mx quantized model
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
* Add precompute_ref_log_probs to config schema
* chore: add description for config
* Add test for precompute_ref_log_probs and move to training args
* useing precompute logprobs as the default slows down CI as it has to precompute
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
* allow bf16 flag but warn
Reason: when doing e.g. LoRA merges with CUDA_VISIBLE_DEVICES=, this will unnecessarily crash, even though the LoRA merge operation would have finished successfully. This seems to warrant changing it to a warning instead, as the code will most likely crash later if bf16 is unavailable and training begins anyway.
* don't use deprecated LOG.warn
* update tests to reflect validation change
* Deperecate dpo_norm_loss
* Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179
* Remove deprecated rpo_alpha
* Remove dead_code tokenize_row
* Add _tokenize override to prevent double bos token on Llama DPO
* Fix DPO loss type now list not string
* Linting fix
* PR fixes
* update _tokenize override for DPO for multimodal
* support flattening/packing for GRPO
* more flattening
* fix tests
* improve dead vllm handling
* refactor out process handling for vllm serve and move bench flattening tests to gpu tests
* add validation for flattening with liger
* isolate batch flattening test
* flaky test
* nemo gym integration with grpo wip
* mostly working
* cleanup
* simplify
* update docs
* nemo gym support wip
* cleanup
* chore: lint
* address PR review and add more tests
* chore: lint
* post merge lora fixes for CI (#3536) [skip ci]
* post merge lora fixes for CI
* handle lora kernel auto-enable for moe without grouped_mm
* prefer not to import torch in schema validation
* address pr comments, add timeout, add tests
* roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* fix: robust handling of race condition on patching check (#3543) [skip ci]
* EBFT: Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models (#3527) [skip ci]
* EBFT wip
* fixes
* more fixeS
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address corereview feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* fix for ebft sync and update docs
* make trainer loss patch check a solo test
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* EBFT wip
* fixes
* more fixeS
* add missing strided module
* ebft fixes for multi-turn
* make ebft work with async
* add example for ebft w qwen3.5
* fix for split thinking and update yaml for lora over linear attention only
* enforce_eager for vllm arg in schema
* fix sync weights
* fix multi-gpu
* handle updated sig for mm
* ddp fixes
* improve multi-gpu handling, don't calculate logits, adaptive completion length
* chore: lint
* chore: lint
* support completion_mean
* Address corereview feedback
* clamp min IS ratio
* Address PR code review
* more fixes identified
* address code review
* Fix property from rebase conflict
* roundup_power2_divisions not needed with newer pytorch versions
* remove typo
* update qwen3.5 moe 35b-a3b yaml for 5090
* more bug fixes
* fix tests to match updated trainer
* don't use fa2 for hooks test
* reset plugins on the instance
* retry download
* fix references to renamed axolotl_cfg property on trainer
* Fix ref to trainer cfg
* feat: LoRA kernel support for bias, dropout, dora, embeddings
* chore: lint
* chore: lint
* address PR feedback, add regression tests, add fsdp2 tests for lora kernels
* update tests for new sigs
* update tests now that bias and dropout are supported
* fix token state json and mistral tokenizer issue
* centralize constants
* forgot to commit constants file
* Fix weakref in pickling relora state dict
* make curl a bit quieter so it doesn't log 2K lines
* fix path traversal for olmoe test
* more test fixes that weren't flagged previously
* chore: lint
* skip tests that fail b/c of OutOfResources
* scattermoe as slow tests
* update fbgemm-genai for torch 2.10
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:
- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py
Closesaxolotl-ai-cloud/axolotl#3331
* optimize moe + lora
* more scattermoe optims
* selective dequant
* add correctness unit tests and benchmarks for scattermoe + lora
* handle base+lora split kernel for older moe models
* chore: lint
* fix casting for H200 and B200
* register pressure estimation and pruning for h200/b200
* use soft limit for pruning
* qkv patch for qwen3.5moe
* support text_model for qwen3.5 moe
* nesting of qwen3
* use udpated cce with zero3 support
* Fix decomposed backward for QKV and O projections
eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
* use custom triton kernels for entropy from logits and selective softmax
* PR comments fixes
* fix out of bounds, include tests, include benchmarks
* chore: lint
* async grpo support
* implement data producer
* use fast async
* handle call to create data producer
* fix liger kernel setup
* fix replay buffer
* chore: lint
* make gpus go brrr
* chore: lint
* inplace div_, unwrap model for logits in bf16
* fuse selective softmax and empty cuda cache on each scoring step
* remove waiting for synch time and fix race
* make fp8 work and allow lora kernels w rl
* grpo with lora vllm sync and fixes for sharded distributed
* update docs
* more patches so it works against trl main
* address PR feedback for corerabbit
* consolidate behavioud of routing in scattermoe kernels
* collect telemetry on best chosen autotuned kernel
* properly collect data
* Fix property name and get smem too
* handle issues raised by coderabbit
* add tests for parity before refactoring
* docs: fix codestyle placeholders in CONTRIBUTING.md
Replace unresolved {codestyle} and {URLofCodestyle} template
variables with Ruff, the project's actual linter/formatter
as configured in .pre-commit-config.yaml.
* fix: replace bare except clauses with specific exception types
- quantization.py: use except ImportError for optional torchao imports
(consistent with line 48 which already uses ImportError correctly)
- cli/config.py: use except (RuntimeError, AssertionError) for CUDA
device property query
Prevents masking unrelated errors like KeyboardInterrupt or SystemExit.
* test: add unit tests for convert.py JSON/JSONL utilities
Cover FileReader, FileWriter, StdoutWriter, JsonParser,
JsonlSerializer, and JsonToJsonlConverter with 8 test cases
including roundtrip and edge case (empty list) scenarios.
Previously this module had zero test coverage.
* fix: address CodeRabbit review feedback
- quantization.py: catch (ImportError, RuntimeError) for optional
torchao imports; CUDA wheel/GPU mismatches raise RuntimeError,
not ImportError
- convert.py: remove unused output_file_path parameter from
JsonToJsonlConverter.convert() — FileWriter already holds the
output path from construction
- tests/test_convert.py: update call site to match new signature
* install flash-linear-attention
* handle prequant weights for fsdp2 and ensure loss is not zero
* fix type for cu_seqlen, uninstall causal_conv1d
* chore: lint
* uv pip uninstall doesn't need confirmation
* upgrade transformers==5.3.0 trl==0.29.0 kernels
* use latest deepspeed fixes
* use corect image for cleanup
* fix test outputs for tokenizer fixes upstream
* fix import:
* keep trl at 0.28.0
* handle updated API
* use latest trl since 0.28.0 doesn't work with latest transformers
* use trl experimental for pad to length
* monkeypatch trl with ORPOTrainer so liger doesn't croak
* upgrade accelerate
* more fixes
* move patch for orpotrainer
* load the imports later
* remove use_logits_to_keep
* fix loss_type arg as a list
* fetch hf cache from s3
* just manually download the missing model for now
* lint for pre-commit update
* a few more missing models on disk
* fix: loss_type internally now list
* fix: remove deprecated code and raise deprecate
* fix: remove unneeded blocklist
* fix: remove reliance on transformers api to find package available
* chore: refactor shim for less sideeffect
* fix: silent trl experimental warning
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
* mxfp4 axo
* import lint
* test for qat mxfp4
* config for mxfp4
* add qat:
* pass base config
* MXFakeQuantizeConfig
* lint
* tune config so it fits in 32GB VRAM
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>