Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
3b5a9d1d88 update create_optimizer for updated api 2026-02-19 23:49:32 -05:00
Wing Lian
eb59070040 fix labels 2026-02-19 23:44:46 -05:00
Wing Lian
9722aaf7d8 fix for tokenizers change 2026-02-19 21:52:44 -05:00
Wing Lian
c5d20bbd79 integration branch for transformers#44041 2026-02-19 18:34:13 -05:00
NanoCode012
7fbedbd300 fix(doc): add limitation for unfrozen_parameters (#3416) 2026-02-19 18:32:26 -05:00
Wing Lian
145ffc9be1 upgrade transformers to 5.2.0 and torchao to 0.16.0 (#3407)
* upgrade transformers to 5.1.0 and torchao to 0.16.0

* upgrade trl for parity

* handle trl api changes

* orpo doesn't have max_prompt_len to check anymore

* cpoconfig doesn't take max_prompt_length and fix cpu offload

* slow fsdp1 test

* triton min 3.4.0 and liger to 0.7.0

* use transformers main for now for zero3 fix

* handle group_by_length change

* fix changes upstream

* mark skip flaky test

* use transformers latest release 5.2.0
2026-02-19 18:27:27 -05:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
35 changed files with 197 additions and 474 deletions

3
.gitignore vendored
View File

@@ -193,6 +193,3 @@ out/
# scm auto-versioning # scm auto-versioning
src/axolotl/_version.py src/axolotl/_version.py
# macOS
.DS_Store

View File

@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options: Configuration options:
```yaml ```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate # List of tasks to evaluate
lm_eval_tasks: lm_eval_tasks:
- arc_challenge - arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results output_dir: # Directory to save evaluation results
``` ```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details. See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4 ### delinearize-llama4

View File

@@ -1,7 +1,8 @@
base_model: google/gemma-3-4b-it base_model: google/gemma-3-4b-it
plugins: # Need to set else transformers tries to load vision too
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -29,6 +30,7 @@ lora_model_dir:
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0

View File

@@ -1,11 +1,12 @@
base_model: google/gemma-3-12b-it base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,7 +7,6 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -1,11 +1,12 @@
base_model: google/gemma-3-12b-it base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,7 +7,6 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -1,11 +1,12 @@
base_model: google/gemma-3-27b-it base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,7 +7,6 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -2,21 +2,21 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1 bitsandbytes==0.49.1
triton>=3.0.0 triton>=3.4.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
liger-kernel==0.6.4 liger-kernel==0.7.0
# END section # END section
packaging==26.0 packaging==26.0
huggingface_hub>=1.1.7 huggingface_hub>=1.1.7
peft>=0.18.1 peft>=0.18.1
tokenizers>=0.22.1 tokenizers>=0.22.1
transformers==5.0.0 transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
accelerate==1.12.0 accelerate==1.12.0
datasets==4.5.0 datasets==4.5.0
deepspeed>=0.18.3 deepspeed>=0.18.3
trl==0.27.1 trl==0.28.0
hf_xet==1.2.0 hf_xet==1.2.0
kernels==0.11.5 kernels==0.11.5
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.13.0 torchao==0.16.0
openenv-core==0.1.0 openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1

View File

@@ -1,225 +0,0 @@
"""Merge trained text-only Gemma3 weights back into a full multimodal checkpoint.
After training with the Gemma3TextFromMultimodalPlugin, the saved checkpoint
contains only the language model weights (with ``model.language_model.*``
prefix, reversed by transformers v5's key_mapping on save).
This script reconstructs a full ``Gemma3ForConditionalGeneration`` checkpoint by
combining the trained language model weights with the original vision tower and
projector weights from the base multimodal model.
Usage::
python scripts/merge_gemma3_multimodal_weights.py \\
--original-model google/gemma-3-4b-it \\
--trained-model /path/to/trained/output \\
--output-dir /path/to/merged
"""
import argparse
import json
import logging
from pathlib import Path
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
LOG = logging.getLogger(__name__)
def collect_safetensors(model_dir: Path) -> dict[str, torch.Tensor]:
"""Load and merge all safetensors shard files in a directory."""
shard_files = sorted(model_dir.glob("*.safetensors"))
if not shard_files:
raise FileNotFoundError(f"No safetensors files found in {model_dir}")
state_dict: dict[str, torch.Tensor] = {}
for shard in shard_files:
LOG.info("Loading %s", shard.name)
state_dict.update(load_file(str(shard)))
return state_dict
def merge(
original_model: str,
trained_model: str,
output_dir: str,
*,
trust_remote_code: bool = False,
) -> None:
original_path = Path(original_model)
trained_path = Path(trained_model)
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# 1. Load the original multimodal checkpoint
LOG.info("Loading original multimodal weights from %s", original_model)
if original_path.is_dir():
original_sd = collect_safetensors(original_path)
else:
from huggingface_hub import snapshot_download
cached = Path(
snapshot_download(original_model, allow_patterns=["*.safetensors"])
)
original_sd = collect_safetensors(cached)
# 2. Load trained text-only weights (already reversed to model.language_model.* by
# transformers v5 key_mapping on save)
LOG.info("Loading trained text-only weights from %s", trained_model)
trained_sd = collect_safetensors(trained_path)
# 3. Classify original keys
lang_keys = {k for k in original_sd if k.startswith("model.language_model.")}
vision_keys = {k for k in original_sd if k.startswith("model.vision_tower.")}
projector_keys = {
k for k in original_sd if k.startswith("model.multi_modal_projector.")
}
other_keys = set(original_sd.keys()) - lang_keys - vision_keys - projector_keys
LOG.info(
"Original checkpoint: %d language, %d vision, %d projector, %d other keys",
len(lang_keys),
len(vision_keys),
len(projector_keys),
len(other_keys),
)
# 4. Classify trained keys (reverse mapping on save gives model.language_model.* prefix)
trained_lang_keys = {k for k in trained_sd if k.startswith("model.language_model.")}
trained_other = set(trained_sd.keys()) - trained_lang_keys
LOG.info(
"Trained checkpoint: %d language keys, %d other keys",
len(trained_lang_keys),
len(trained_other),
)
# 5. Build merged state dict
merged: dict[str, torch.Tensor] = {}
# Keep vision tower and projector from original
for key in vision_keys | projector_keys:
merged[key] = original_sd[key]
# Use trained language model weights (overwrite original)
for key in trained_lang_keys:
merged[key] = trained_sd[key]
# For other trained keys (like lm_head.weight), use trained version
for key in trained_other:
merged[key] = trained_sd[key]
# For any original other keys not covered by trained (shouldn't usually happen),
# keep original
for key in other_keys:
if key not in merged:
merged[key] = original_sd[key]
# Check for missing language keys that were in original but not in trained
missing_lang = lang_keys - trained_lang_keys
if missing_lang:
LOG.warning(
"%d language keys in original but not in trained; keeping original: %s",
len(missing_lang),
list(missing_lang)[:5],
)
for key in missing_lang:
merged[key] = original_sd[key]
LOG.info("Merged checkpoint: %d total keys", len(merged))
# 6. Save merged weights (sharded at 50GB, matching transformers default)
LOG.info("Saving merged weights to %s", out_path)
state_dict_split = split_torch_state_dict_into_shards(merged, max_shard_size="50GB")
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {name: merged[name] for name in tensors}
save_file(shard, str(out_path / filename))
if state_dict_split.is_sharded:
index = {
"metadata": {
"total_size": sum(t.numel() * t.element_size() for t in merged.values())
},
"weight_map": state_dict_split.tensor_to_filename,
}
with open(out_path / "model.safetensors.index.json", "w") as f:
json.dump(index, f, indent=2)
LOG.info("Saved %d shards", len(state_dict_split.filename_to_tensors))
# 7. Copy/update config
LOG.info("Writing config.json")
original_config = AutoConfig.from_pretrained(
original_model, trust_remote_code=trust_remote_code
)
# Update text_config fields from trained model's config if available
trained_config_path = trained_path / "config.json"
if trained_config_path.exists():
with open(trained_config_path) as f:
trained_config_dict = json.load(f)
# The trained config is the text sub-config; merge its fields into
# the original composite config's text_config
if hasattr(original_config, "text_config"):
for key, val in trained_config_dict.items():
if key not in ("model_type", "_name_or_path", "architectures"):
if hasattr(original_config.text_config, key):
setattr(original_config.text_config, key, val)
original_config.save_pretrained(out_path)
# 8. Copy tokenizer files from trained model if present
tokenizer_files = list(trained_path.glob("tokenizer*")) + list(
trained_path.glob("special_tokens_map*")
)
if tokenizer_files:
import shutil
for tok_file in tokenizer_files:
shutil.copy2(tok_file, out_path / tok_file.name)
LOG.info("Copied %d tokenizer files", len(tokenizer_files))
LOG.info("Merge complete. Output saved to %s", out_path)
def main():
parser = argparse.ArgumentParser(
description="Merge trained text-only Gemma3 weights back into a multimodal checkpoint."
)
parser.add_argument(
"--original-model",
required=True,
help="HuggingFace model ID or local path to the original multimodal model",
)
parser.add_argument(
"--trained-model",
required=True,
help="Local path to the trained text-only model output directory",
)
parser.add_argument(
"--output-dir",
required=True,
help="Directory to save the merged multimodal checkpoint",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
default=False,
help="Trust remote code when loading model config",
)
args = parser.parse_args()
merge(
original_model=args.original_model,
trained_model=args.trained_model,
output_dir=args.output_dir,
trust_remote_code=args.trust_remote_code,
)
if __name__ == "__main__":
main()

View File

@@ -246,7 +246,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters ddp_find_unused_parameters
) )
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,7 +11,6 @@ from axolotl.core.trainers import (
) )
from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback from axolotl.utils.callbacks.qat import QATCallback
@@ -53,6 +52,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}: if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1 sequence_parallel=self.cfg.context_parallel_size > 1
) )
@@ -133,21 +134,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
# Handle when max_prompt_length == max_length from defaults blocklist_args_kwargs.append("max_prompt_length")
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO: elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length # KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs = ["max_prompt_length"] blocklist_args_kwargs.append("max_prompt_length")
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0 self.cfg.kto_desirable_weight or 1.0
@@ -157,6 +154,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
) )
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}: elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class() training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -57,16 +57,18 @@ class AxolotlDPOTrainer(
def tokenize_row( def tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length, max_prompt_length: int | None = None,
max_completion_length, max_completion_length: int | None = None,
add_special_tokens, add_special_tokens: bool = True,
is_chat: bool = False,
) -> Dict: ) -> Dict:
res = DPOTrainer.tokenize_row( res = DPOTrainer.tokenize_row(
features, features,
processing_class, processing_class,
max_prompt_length, max_prompt_length=max_prompt_length,
max_completion_length, max_completion_length=max_completion_length,
add_special_tokens, add_special_tokens=add_special_tokens,
is_chat=is_chat,
) )
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self, model=None):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None and self.optimizer_cls_and_kwargs is None
): ):
return super().create_optimizer() return super().create_optimizer(model=model)
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model if model is None else model
if ( if (
not self.optimizer not self.optimizer

View File

@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like( def patch_llama_like(
self, self,
model_type: str, model_type_to_patch: str,
) -> None: ) -> None:
""" """
Generic patch for model architectures with causal lm similar to llama Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic( def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
): ):
import cut_cross_entropy.transformers.llama import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e
if model_type not in PATCH_FNS: if model_type_to_patch not in PATCH_FNS:
LOG.warning_once( LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type "Setting up generic cce patch for model type: %s", model_type_to_patch
) )
LOG.warning_once( LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected." f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
) )
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -1,37 +0,0 @@
# Gemma3 Text-from-Multimodal Plugin
Load a Gemma3 multimodal checkpoint (e.g. `google/gemma-3-4b-it`) directly into `Gemma3ForCausalLM` for text-only training. This bypasses the multimodal trainer path and enables sample packing and other text-specific optimizations.
## How it works
The plugin uses transformers v5's `key_mapping` parameter on `from_pretrained` to remap `model.language_model.*` checkpoint keys to `model.*`, matching what `Gemma3ForCausalLM` expects. Vision tower and projector weights are automatically discarded. On save, transformers reverses the mapping so checkpoints retain the original `model.language_model.*` prefix.
## Usage
Add the plugin to your YAML config:
```yaml
base_model: google/gemma-3-4b-it
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
```
See `examples/gemma3/gemma-3-4b-qlora.yml` for a complete example.
## Merging weights back into a multimodal checkpoint
After training, the saved checkpoint contains only the language model weights. To reconstruct a full `Gemma3ForConditionalGeneration` checkpoint (with the original vision tower and projector), use the merge script:
```bash
python scripts/merge_gemma3_multimodal_weights.py \
--original-model google/gemma-3-4b-it \
--trained-model /path/to/trained/output \
--output-dir /path/to/merged
```
This combines:
- **Trained language model weights** from your output checkpoint
- **Original vision tower + projector weights** from the base multimodal model
The merged checkpoint can be loaded as `Gemma3ForConditionalGeneration` for multimodal inference or further training.

View File

@@ -1,9 +0,0 @@
"""Gemma3 integration for loading multimodal checkpoints as text-only models."""
from .args import Gemma3TextFromMultimodalArgs
from .plugin import Gemma3TextFromMultimodalPlugin
__all__ = [
"Gemma3TextFromMultimodalArgs",
"Gemma3TextFromMultimodalPlugin",
]

View File

@@ -1,31 +0,0 @@
"""Pydantic input args for the Gemma3 text-from-multimodal plugin."""
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class Gemma3TextFromMultimodalArgs(BaseModel):
"""Configuration args for loading a Gemma3 multimodal checkpoint as text-only."""
gemma3_text_from_multimodal: bool = True
extract_text_config: bool = False
@model_validator(mode="before")
@classmethod
def set_model_type(cls, data):
if not isinstance(data, dict):
return data
if not data.get("gemma3_text_from_multimodal", True):
return data
if not data.get("model_type"):
LOG.info(
"Gemma3TextFromMultimodalPlugin: auto-setting model_type to Gemma3ForCausalLM"
)
data["model_type"] = "Gemma3ForCausalLM"
return data

View File

@@ -1,107 +0,0 @@
"""Plugin for loading Gemma3 multimodal checkpoints into Gemma3ForCausalLM (text-only).
Uses transformers v5's ``key_mapping`` parameter on ``from_pretrained`` to remap
``model.language_model.*`` keys to ``model.*``, discarding vision tower and projector
weights. On save, transformers automatically reverses the mapping so saved
checkpoints retain the original ``model.language_model.*`` prefix.
"""
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# key_mapping for transformers from_pretrained:
# Remap checkpoint keys matching ^model.language_model -> model
# Vision tower / projector keys won't match any model parameter and are discarded.
GEMMA3_KEY_MAPPING = {"^model.language_model": "model"}
class Gemma3TextFromMultimodalPlugin(BasePlugin):
"""Load a Gemma3 multimodal checkpoint as a text-only Gemma3ForCausalLM.
Hooks
-----
register(cfg)
Runs before config validation. Sets the ``_extract_text_config`` flag,
ensures ``model_type`` is ``Gemma3ForCausalLM``, and injects
``key_mapping`` into ``model_kwargs`` so that ``from_pretrained`` remaps
``model.language_model.*`` → ``model.*``.
pre_model_load(cfg)
Runs after config validation/normalization but before model instantiation.
Validates that ``model_config_type`` is ``gemma3_text`` and
``is_multimodal`` is False (confirming that ``_extract_text_config``
worked correctly).
"""
def get_input_args(self) -> str:
return "axolotl.integrations.gemma3.Gemma3TextFromMultimodalArgs"
def register(self, cfg: dict):
"""Set up config for multimodal → text-only loading.
This runs before Pydantic validation, so ``cfg`` is a raw dict.
"""
if not cfg.get("gemma3_text_from_multimodal", True):
raise ValueError(
"Gemma3TextFromMultimodalPlugin: disabled via config, but plugin selected"
)
# Flag for load_model_config() to extract the text sub-config
cfg["extract_text_config"] = True
# Ensure model_type is set for the text-only model class
if not cfg.get("model_type"):
cfg["model_type"] = "Gemma3ForCausalLM"
# Inject key_mapping into model_kwargs so from_pretrained remaps weights
model_kwargs = cfg.setdefault("model_kwargs", {})
model_kwargs["key_mapping"] = GEMMA3_KEY_MAPPING
def pre_model_load(self, cfg):
"""Validate that config extraction worked before model instantiation."""
if not getattr(cfg, "gemma3_text_from_multimodal", True):
return
if cfg.model_config_type != "gemma3_text":
LOG.warning(
"Gemma3TextFromMultimodalPlugin: expected model_config_type='gemma3_text' "
"but got '%s'. The text config extraction may not have worked.",
cfg.model_config_type,
)
if cfg.is_multimodal or cfg.processor_type:
raise ValueError(
"Multimodal mode is enabled (processor_type set), but "
"Gemma3TextFromMultimodalPlugin enabled. "
"Please disable one of the two."
)
def post_train(self, cfg, model):
"""Log merge command after training completes."""
if cfg.adapter:
LOG.info(
"Adapter training detected. To reconstruct the multimodal checkpoint:\n"
" 1. Merge adapter weights into the text-only base model:\n"
" axolotl merge_lora <your_config.yml>\n"
" 2. Then merge the resulting full model back into the multimodal checkpoint:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s/merged \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)
else:
LOG.info(
"To merge trained weights back into the multimodal checkpoint, run:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)

View File

@@ -0,0 +1,44 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -6,6 +6,12 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage ## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml ```yaml
plugins: plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin - axolotl.integrations.lm_eval.LMEvalPlugin
@@ -16,9 +22,50 @@ lm_eval_tasks:
- arc_easy - arc_easy
lm_eval_batch_size: # Batch size for evaluation lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
``` ```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation ## Citation
```bib ```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec import subprocess # nosec
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from .args import LMEvalArgs as LMEvalArgs from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id, model=get_model_path(cfg),
): ):
subprocess.run( # nosec subprocess.run( # nosec
lm_eval_args, lm_eval_args,

View File

@@ -13,6 +13,21 @@ import yaml
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command( def build_lm_eval_command(
tasks: list[str], tasks: list[str],
bfloat16=True, bfloat16=True,
@@ -108,7 +123,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project, wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity, wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name, wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id, model=get_model_path(cfg),
revision=cfg.revision, revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template, apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn, fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -10,6 +10,7 @@ from functools import cached_property
import addict import addict
import transformers import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
@@ -500,6 +501,7 @@ class PatchManager:
and not self.cfg.trust_remote_code and not self.cfg.trust_remote_code
and not self.cfg.gptq and not self.cfg.gptq
and self.cfg.flash_attention and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference and not self.inference
): ):
# TODO(MengqingCao): split these patches separately # TODO(MengqingCao): split these patches separately

View File

@@ -204,13 +204,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
check_model_config(cfg, model_config) check_model_config(cfg, model_config)
# Extract text config from composite config when explicitly requested
# (set by plugins like Gemma3TextFromMultimodalPlugin)
if getattr(cfg, "extract_text_config", False) and hasattr(
model_config, "get_text_config"
):
model_config = model_config.get_text_config()
return model_config return model_config

View File

@@ -59,7 +59,12 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach() hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True hidden_states.requires_grad = True
with torch.enable_grad(): with torch.enable_grad():
(output,) = ctx.forward_function(hidden_states, *ctx.args) output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
torch.autograd.backward(output, dY) torch.autograd.backward(output, dY)
return ( return (
None, None,

View File

@@ -28,8 +28,12 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
} }
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" ORIGINAL_MAYBE_CODE = (
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" "tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
def check_evaluation_loop_is_patchable() -> bool: def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -446,7 +446,16 @@ class AxolotlInputConfig(
}, },
) )
unfrozen_parameters: list[str] | None = None unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
sequence_len: int = Field( sequence_len: int = Field(
default=512, default=512,

View File

@@ -247,7 +247,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=F
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3", "gemma3_text"] drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask: if drop_attn_mask:
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")

View File

@@ -300,7 +300,6 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl) self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific # ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer): def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer) builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -186,6 +186,7 @@ class TestFSDP1:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {

View File

@@ -365,6 +365,7 @@ class TestFSDP2:
verify_training_success(temp_dir) verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0 @require_torch_2_7_0
def test_dpo_fft(self, temp_dir): def test_dpo_fft(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -115,6 +115,9 @@ class TestAssistantChatTemplateLlama3:
def test_phi35(self, phi35_tokenizer, assistant_dataset): def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset") LOG.info("Testing phi-3.5 with assistant dataset")
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
"phi35 tokenizer should be a LlamaTokenizer"
)
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
phi35_tokenizer, phi35_tokenizer,
@@ -140,13 +143,13 @@ class TestAssistantChatTemplateLlama3:
# fmt: off # fmt: off
expected_input_ids = [ expected_input_ids = [
32010, # user 32010, # user
22172, 32007, # user eot 12199, 32007, # user eot
32001, # assistant 32001, # assistant
22172, 32007, # assistant eot 12199, 32007, # assistant eot
32010, # user 32010, # user
1781, 26966, 32007, # user eot 16773, 26966, 32007, # user eot
32001, # assistant 32001, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
expected_labels = [ expected_labels = [
-100, # user -100, # user
@@ -156,7 +159,7 @@ class TestAssistantChatTemplateLlama3:
-100, # user -100, # user
-100, -100, -100, # user eot -100, -100, -100, # user eot
-100, # assistant -100, # assistant
1781, 26966, 32007, # assistant eot 16773, 26966, 32007, # assistant eot
] ]
# fmt: on # fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}") LOG.debug(f"Expected input_ids: {expected_input_ids}")

View File

@@ -84,7 +84,8 @@ class TestTokenizers:
} }
) )
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404] assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert len(tokenizer) == 32001 assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length # ensure reloading the tokenizer again from cfg results in same vocab length