Compare commits

..

33 Commits

Author SHA1 Message Date
NanoCode012
be1f8db913 Merge branch 'main' into feat/glm45 2025-12-25 17:50:09 +07:00
Abubakar Abid
f2155eaf79 feat: add trackio as experiment tracking integration (#3253)
* feat: add trackio as experiment tracking integration

- Add TrackioConfig to integrations schema with project_name, run_name, and space_id
- Create trackio_.py module for environment setup
- Add is_trackio_available() utility function
- Integrate trackio with report_to in trainer builder
- Add trackio callback for experiment tracking
- Add trackio config keys to gpt-oss example YAMLs
- Trackio runs locally by default, syncs to HF Space if space_id provided

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* Update requirements.txt

* don't allow pydantic 2.12 for now

---------

Co-authored-by: Abubakar Abid <aaabid93@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-23 08:49:07 -05:00
kallewoof
92ee4256f7 feature: raise on long sequence drop (#3321)
* feature: raise on long sequence drop

It is sometimes not desired that sequences are silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted for the training context. This would then suggest that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which will raise a ValueError if a sequence is encountered that is too long and would have normally been dropped/truncated.

* tests: add excess_length_strategy tests

* doc: updated return value description for drop_long_seq_in_dataset

* add @enable_hf_offline

* fixed cfg modified after validate_config called

* hf offline fix

* fix tqdm desc when raise is used

* test: added test for non-batched case

* accidental code change revert

* test: use pytest.raises

* test: simplified drop_seq_len tests

* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-12-22 13:59:49 -05:00
Wing Lian
efeb5a4e41 fix check for fp8 capability (#3324)
* fix check for fp8 capability

* handle non-cuda compute

* reduce concurrency of tests
2025-12-22 13:58:25 -05:00
VED
faaff6c792 allow users to set ndigits for rounding of metrics when logging (#3325)
* METRIC_PRECISION-> 8

* use ndigits and move env getter to top of log function

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-22 08:54:43 -05:00
Alexander Kozhevnikov
43cef27458 Fix typo in densemixer RuntimeError (#3327) [skip ci]
It offers installing densemizer while it should be densemixer
2025-12-22 08:53:58 -05:00
Wing Lian
07c41a6c2a fix preview docs failing due to running out of disk (#3326) [skip ci]
* fix preview docs failing due to running out of disk

* fix docs publish too
2025-12-19 11:34:55 -05:00
salman
bbd3486f57 Distributed Muon Optimizer (#3264)
* init

* working

* updating configs

* removing unneeded files

* lint

* comments

* lint

* fix regex match

* bump contribs version

* comments

* fixing tests and imports

* muon imports in test v2

* test cleanup

* bump contribs version

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2025-12-19 10:43:47 -05:00
VED
3750d7dd64 add liger support kernal for dpo (#3302)
* add liger kernal 4 dpo

* revert grpo changes,add support in dpo

* revert grpo changes,add support in dpo

* dpo_use_liger_kernal

* fix liger_dpo

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-18 11:11:06 -05:00
xzuyn
2197b0bf89 feat: cheap ppl metric (#3317)
* Import math and compute perplexity from loss values

* lint

* coderabbit changes

* lint

* fix: add rounding to ppl

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-18 09:02:41 -05:00
Seung Hyun Cho
3e51a680c2 fix: Fix evaluation loss in KD trainer (#3271)
* fix: Fix evaluation loss in KD trainer

* Fix v2 strategy super() call

* fix: Add safety check for total_tokens in log method

* fix: simplified num items and outputs return handling

* fix: add missing model forward pass in compute_loss

* refactor: Use Template Method pattern for chat template strategies

* refactor: use pop(None) and remove v2 override

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-17 13:40:36 -05:00
xzuyn
2cf254b4af Add peft_autocast_adapter_dtype config option (#3311) [skip ci]
* Add `peft_autocast_adapter_dtype` field to schema

* Add `autocast_adapter_dtype` to `model_kwargs`

* chore: docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-12-17 10:09:39 -05:00
salman
83d4d97dcc Add QAT NVFP4 configs for blogpost (#3280) [skip ci]
* add configs for blogpost

* fix configs

* fixing baseline configs
2025-12-17 09:35:22 -05:00
NanoCode012
a1d07f42e4 Fix(misc): address PYTORCH_CUDA_ALLOC_CONF deprecate (#3313)
* fix: leftover ministral docs changes

* fix: pytorch_cuda_alloc_conf deprecation

* fix: set old PYTORCH_CUDA_ALLOC_CONF env too

* handle 2.9 separately

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-17 09:12:18 -05:00
Wing Lian
2a664dc8ad support for xformers wheels for torch 2.9 (#3308)
* support for xformers wheels for torch 2.9

* fix hf cache?

* don't use hf cache from s3

* show disk free space in ci
2025-12-11 11:56:40 -05:00
NanoCode012
4ac78aa562 fix: update qwen3 jinja tokenization off a few tokens (#3295)
* fix: update qwen3 jinja tokenization off a few tokens

* fix: add note on tokenization issue

* fix: pop last index for mistral tokenizer
2025-12-09 14:31:03 +07:00
VED
b3f4aa149f fix bin size (#3307)
* fix bin size

* lint

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
2025-12-08 09:16:18 -05:00
salman
75b20fb66f Save processor in quantizer CLI (#3290) 2025-12-06 16:27:18 +00:00
NanoCode012
5992e607a2 fix: improve ministral3 docs to be clearer (#3300)
* fix: improve ministral3 docs to be clearer

* fix: title

* chore: wording
2025-12-04 21:44:44 +07:00
NanoCode012
2b66ee189c Feat: add ministral3 (#3297)
* feat: add ministral and mistral3

* chore: lint

* feat: update cce for ministral

* fix: add vram usage

* feat: update for release

* fix: save_pretrained issue in v5

* fix: add instructions to use v5 branch

* fix: add to multipack

* fix: improve instructions

* fix: add model to readme
2025-12-04 08:32:08 -05:00
NanoCode012
86d8cca149 Feat: add trinity by ArceeAI (#3292) 2025-12-02 13:12:55 -05:00
NanoCode012
4a0f98e612 feat: upgrade liger to 0.6.4 (#3289) 2025-12-02 09:16:23 -05:00
NanoCode012
a526647b31 Merge branch 'main' into feat/glm45 2025-11-28 13:41:25 +07:00
NanoCode012
8069177284 Merge branch 'main' into feat/glm45 2025-11-10 21:41:05 +07:00
NanoCode012
a28eb600e9 feat: add readme and better examples 2025-08-13 13:57:15 +07:00
NanoCode012
4b16f363bc fix: move 2025-08-13 10:46:42 +07:00
NanoCode012
272a456ec0 fix: remove lora in fft config 2025-08-12 20:34:47 +07:00
NanoCode012
7e83268662 feat: add wip fft offload config 2025-08-12 20:34:47 +07:00
NanoCode012
b2a8c37a27 fix: use smaller model 2025-08-12 20:34:47 +07:00
NanoCode012
603166d9c5 feat: add example config 2025-08-12 20:34:47 +07:00
NanoCode012
e8c9517ac8 feat: add to multipack 2025-08-12 20:34:47 +07:00
NanoCode012
0bbad9202c feat: add glm4moemoe to z3 2025-08-12 20:34:47 +07:00
NanoCode012
cb042e9775 feat: add cce for glm4_moe & deepseek v3 2025-08-12 20:32:46 +07:00
82 changed files with 2431 additions and 138 deletions

View File

@@ -12,6 +12,9 @@ jobs:
build-deploy:
runs-on: ubuntu-latest
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Quarto

View File

@@ -11,6 +11,7 @@ on:
- '_quarto.yml'
- docs/scripts/generate_config_docs.py
- src/axolotl/utils/schemas/**.py
- .github/workflows/preview-docs.yml
permissions:
checks: write
@@ -27,6 +28,10 @@ jobs:
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
with:

View File

@@ -66,12 +66,12 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -113,9 +113,13 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
@@ -145,12 +149,12 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p ~/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
#
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -188,11 +192,11 @@ jobs:
axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
run: hf cache scan
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/

View File

@@ -10,6 +10,7 @@ ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src

View File

@@ -29,7 +29,7 @@
## 🎉 Latest Updates
- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3).
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
]
},
{
@@ -253,7 +253,6 @@
"source": [
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
"\n",
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
"set_pytorch_cuda_alloc_conf()"
]
},

48
examples/glm45/README.md Normal file
View File

@@ -0,0 +1,48 @@
# Finetune GLM4.5 with Axolotl
[UNSTABLE]
```bash
# LoRA SFT (4xH200 @ 84GB/GPU)
axolotl train examples/glm45/glm4.5-lora-fsdp2.yaml
# FFT SFT (4xH200)
# Checkpointing error on backward pass
# Without checkpointing => OOM
axolotl train examples/glm45/glm4.5-fft-fsdp2.yaml
```
## Dataset
In addition to normal OpenAI Messages format, GLM4.5 support an extra parameter for thinking in assistant section.
```json
{
"role": "assistant",
"reasoning_content": "...", // or have </think>...</think> in `content`
"content": "...",
}
```
Note:
- The role name for tools in this template is `tool`.
- You will see this Axolotl WARNING. This is to be as expected as the template does not use EOS.
```bash
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
```
- Make sure you set the below extra attributes if needed
```yaml
datasets:
- path: ...
type: chat_template
message_property_mappings:
role: role
content: content
# tool_calls: tool_calls # uncomment if using tools
# reasoning_content: reasoning_content # uncomment if have reasoning
# Uncomment if training on tool role (you would rarely if ever need this)
# eot_tokens:
# - <|observation|>
```

View File

@@ -0,0 +1,59 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
# gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
state_dict_type: SHARDED_STATE_DICT
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,74 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: lora
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
# gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
state_dict_type: SHARDED_STATE_DICT
reshard_after_forward: true
# activation_checkpointing: false

View File

@@ -32,6 +32,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -29,6 +29,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,6 +28,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,6 +41,10 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -29,7 +29,6 @@ flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true
wandb_project:

View File

@@ -30,7 +30,7 @@ eval_sample_packing: true
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1
warmup_ratio: 0.1
warmup_steps: 0.1
optimizer: adamw_8bit
lr_scheduler: cosine
@@ -44,7 +44,7 @@ resume_from_checkpoint:
sdp_attention: true
logging_steps: 1
save_strategy: epoch
save_strategy: best
eval_strategy: epoch
special_tokens:

View File

@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -0,0 +1,50 @@
# Finetune Ministral with Axolotl
Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/ministral/ministral-small-qlora.yaml
```
This config uses about 8.76 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-8B-Instruct-2410
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,79 @@
# Finetune Ministral3 with Axolotl
Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
Note: This is still experimental given it is based on transformers v5 RC.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Swap to the Axolotl transformers v5 branch
```bash
cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml
git fetch
git checkout transformers-v5
# Install packages for transformers v5
pip install -e .
```
4. Run the fine-tuning:
```bash
axolotl train ministral3-3b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### Tips
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### Thinking
Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
### Vision
Ministral3 2512 model also supports vision capabilities.
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,73 @@
# Ministral3 2512 Thinking Fine-tuning
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl (see [main README](../README.md))
## Getting Started
Run the thinking model fine-tuning:
```bash
axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml
```
This config uses about 4.76 GiB VRAM.
### Tips
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
## Dataset Format
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
Example format:
```json
{
"messages": [
{
"role": "system",
"content": [
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
]
},
{
"role": "user",
"content": [
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
]
},
{
"role": "assistant",
"content": [
{
"type": "thinking",
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
},
{
"type": "text",
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
}
]
}
]
}
```
### Advanced Options
The `thinking` section supports an optional `closed` parameter:
```json
{
"type": "thinking",
"thinking": "Internal reasoning here...",
"closed": true // Default: true, controls adding the closing [/THINK] tag
}
```

View File

@@ -0,0 +1,67 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: Nanobit/text-think-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,57 @@
# Ministral3 2512 Vision Fine-tuning
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.
## Prerequisites
Before starting, ensure you have:
- Installed Axolotl from source (see [main README](../README.md#getting-started))
## Getting started
1. Install the required vision lib:
```bash
pip install 'mistral-common[opencv]==1.8.6'
```
2. Download the example dataset image:
```bash
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
```
3. Run the fine-tuning:
```bash
axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml
```
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
### Tips
Key differences from text-only model:
- Multi-modal dataset format required
- Sample packing not supported
## Dataset Format
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
Example:
```json
{
"messages": [
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
{"role": "user", "content": [
{ "type": "text", "text": "What's in this image?"},
{"type": "image", "path": "path/to/image.jpg" }
]},
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
],
}
```
## Limitations
- Sample Packing is not supported for multi-modality training currently.

View File

@@ -0,0 +1,64 @@
base_model: mistralai/Ministral-3-3B-Reasoning-2512
processor_type: AutoProcessor
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# sample dataset below requires downloading image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,24 +6,16 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Install Cut Cross Entropy
python scripts/cutcrossentropy_install.py | sh
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
2. Run the finetuning example:
```bash
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### TIPS

View File

@@ -0,0 +1,67 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/out_gemma/
sequence_len: 8096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out_gemma/
sequence_len: 8096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_gemma/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_gemma/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_gemma27/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,73 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_gemma27/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-72B
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/out_math_72b/
sequence_len: 4096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-72B
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: AI-MO/NuminaMath-CoT
type: chat_template
output_dir: ./outputs/qat_out_math_72b/
sequence_len: 4096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-72B
# Alpaca finetuning configuration for Qwen2.5-72B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/out_qwen72b/
sequence_len: 8096
sample_packing: true
flash_attention: true
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-72B
# Alpaca finetuning configuration for Qwen2.5-72B
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: qwen_25
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/qat_out_qwen72b/
sequence_len: 8096
sample_packing: true
flash_attention: true
qat:
activation_dtype: nvfp4
weight_dtype: nvfp4
group_size: 16 # only group_size of 16 is supported with nvfp4
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
resume_from_checkpoint:
logging_steps: 1
# evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,70 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-adamw-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: adamw
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# AdamW optimizer settings (standard LR for AdamW)
optimizer: adamw_torch_fused
learning_rate: 0.0002
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

View File

@@ -0,0 +1,70 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-muon-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: muon
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# Muon optimizer settings
optimizer: muon
learning_rate: 0.02
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

46
examples/qwen3/README.md Normal file
View File

@@ -0,0 +1,46 @@
# Finetune Qwen3 with Axolotl
[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/qwen3/32b-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
### Chat template masking a few tokens off
If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.
```yaml
chat_template: qwen3
```
### TIPS
- For inference, please check the official model card as it depends on your reasoning mode.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,38 @@
# Finetune ArceeAI's Trinity with Axolotl
[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,67 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# CCE - N/A as of now
# plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
# flash_attention: true # Not supported
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -5,7 +5,7 @@ bitsandbytes==0.48.2
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.6.3
liger-kernel==0.6.4
# END section
packaging==23.2
@@ -20,15 +20,16 @@ deepspeed>=0.17.0
trl==0.25.0
hf_xet==1.2.0
kernels>=0.9.0
trackio
trackio>=0.13.0
typing_extensions>=4.14.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.49.1
gradio>=6.2.0,<7.0
modal==1.0.2
pydantic>=2.10.6
pydantic>=2.10.6,<2.12
addict
fire
PyYAML>=6.0
@@ -67,9 +68,8 @@ openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.5
axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.5
mistral-common==1.8.6

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
)

View File

@@ -66,7 +66,6 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]

View File

@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trackio_ import setup_trackio_env_vars
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -227,6 +228,7 @@ def load_cfg(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
@@ -245,6 +247,7 @@ def load_cfg(
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
setup_trackio_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
@@ -259,3 +262,11 @@ def load_cfg(
)
return cfg
def compute_supports_fp8() -> bool:
try:
compute_capability = torch.cuda.get_device_capability()
return compute_capability >= (9, 0)
except RuntimeError:
return False

View File

@@ -288,8 +288,8 @@ def do_inference_gradio(
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
demo.launch(
footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -45,6 +45,7 @@ def cli():
print_axolotl_text_art()
load_dotenv()
set_pytorch_cuda_alloc_conf()
set_misc_env()
@cli.command()

View File

@@ -8,7 +8,7 @@ from typing import Union
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import (
TorchAOQuantDType,
@@ -66,6 +66,11 @@ def do_quantize(
LOG.info(f"Loading model from {model_path}.")
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
config = AutoConfig.from_pretrained(model_path)
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
model = AutoModelForCausalLM.from_pretrained(
@@ -107,6 +112,10 @@ def do_quantize(
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
processor.save_pretrained(str(Path(output_dir) / "quantized"))
if hub_model_id:
hub_model_id = (
hub_model_id.rstrip("-")
@@ -114,6 +123,8 @@ def do_quantize(
)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")

View File

@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
outputs=[masked_preview, html_out],
)
demo.queue().launch(
show_api=False,
demo.launch(
footer_links=["gradio", "settings"],
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -14,7 +14,9 @@ MOE_ARCH_BLOCK = {
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"glm4_moe": "Glm4MoeMoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
}

View File

@@ -35,6 +35,7 @@ from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
is_trackio_available,
)
from axolotl.utils.callbacks import (
GCCallback,
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_trackio and is_trackio_available():
from axolotl.utils.callbacks.trackio_ import (
SaveAxolotlConfigtoTrackioCallback,
)
callbacks.append(
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
@@ -281,11 +290,22 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
from axolotl.contribs.mit.muon.dist_muon import (
DistMuonOptimizerFactory,
)
optimizer_cls = DistMuonOptimizerFactory
optimizer_kwargs["device_mesh"] = device_mesh
else:
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import (
@@ -423,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
if self.cfg.use_trackio:
report_to.append("trackio")
training_args_kwargs["report_to"] = report_to
@@ -430,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
elif self.cfg.use_trackio:
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
else:
training_args_kwargs["run_name"] = None

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import math
import os
from collections import defaultdict
from functools import partial, wraps
@@ -603,6 +604,7 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
@@ -613,7 +615,18 @@ class AxolotlTrainer(
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(fn(values).item(), 4)
logs[key] = round(fn(values).item(), metric_ndigits)
if "loss" in logs:
try:
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
except OverflowError:
logs["ppl"] = float("inf")
if "eval_loss" in logs:
try:
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
except OverflowError:
logs["eval_ppl"] = float("inf")
if is_main_process():
# Add memory usage
@@ -631,8 +644,10 @@ class AxolotlTrainer(
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
if hasattr(self.state, "total_tokens"):
if (
hasattr(self.state, "total_tokens")
and self.state.total_tokens is not None
):
logs["total_tokens"] = int(self.state.total_tokens.item())
del self._stored_metrics[train_eval]

View File

@@ -36,4 +36,6 @@ class DPOStrategy:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
```
## Usage
@@ -44,6 +44,7 @@ plugins:
- gemma3n_text
- glm
- glm4
- glm_moe
- glm4_moe
- glm4v
- glm4v_moe
@@ -61,6 +62,8 @@ plugins:
- llama4
- llama4_text
- llava
- ministral
- ministral3
- mistral
- mistral3
- mixtral

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
)

View File

@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
if cfg.dense_mixer:
if not importlib.util.find_spec("densemixer"):
raise RuntimeError(
"DenseMixer is not installed. Install it with `pip install densemizer`"
"DenseMixer is not installed. Install it with `pip install densemixer`"
)
from densemixer.patching import (

View File

@@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
# let subclasses add fields before transform
tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
"""
Hook for subclasses to prepare additional KD fields before transform
"""
return tokenized_prompt
@@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
return sample
def _tokenize_single_prompt(self, prompt):
target_token_ids = prompt.get("target_token_ids", None)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
"""
Add pre-tokenized target_token_ids for v2 format
"""
target_token_ids = original_prompt.pop("target_token_ids", None)
if target_token_ids is not None:
tokenized_prompt["target_token_ids"] = target_token_ids
return tokenized_prompt

View File

@@ -16,6 +16,8 @@
KD trainer
"""
from typing_extensions import override
from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
@@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
if columns_to_add:
self._signature_columns += columns_to_add
@override
def compute_loss(
self,
model,
@@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
):
del inputs["attention_mask"]
if num_items_in_batch is None and "labels" in inputs:
num_items_in_batch = (inputs["labels"] != -100).sum().item()
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
return outputs[0]
if isinstance(outputs, dict):
loss = outputs["loss"]
elif isinstance(outputs, tuple):
loss = outputs[0]
else:
loss = outputs.loss if hasattr(outputs, "loss") else outputs
return (loss, outputs) if return_outputs else loss

View File

@@ -142,9 +142,12 @@ def load_lora(
):
setup_quantized_meta_for_peft(model)
model_kwargs: Any = {}
if cfg.peft_autocast_adapter_dtype is not None:
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
model_kwargs: Any = {}
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
@@ -155,7 +158,7 @@ def load_lora(
**model_kwargs,
)
else:
model = get_peft_model(model, lora_config)
model = get_peft_model(model, lora_config, **model_kwargs)
if rank == 0:
try:

View File

@@ -37,6 +37,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"deepseek_v3",
"glm",
"glm4",
"glm4_moe",
"smollm3",
"granite",
"granitemoe",
@@ -52,6 +53,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"olmo",
"olmo2",
"olmo3",
"ministral",
"ministral3",
"afmoe",
]

View File

@@ -95,6 +95,7 @@ class ChatTemplatePrompter(Prompter):
add_generation_prompt=False,
images=None,
tools=None,
real_last_index=None,
):
"""
Build a prompt from a conversation.
@@ -114,6 +115,9 @@ class ChatTemplatePrompter(Prompter):
if tools:
chat_template_kwargs["tools"] = tools
if real_last_index:
chat_template_kwargs["real_last_index"] = real_last_index
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
@@ -631,11 +635,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
turns_with_empty = turns[:turn_idx] + [empty_turn]
turns_with_content = turns[: turn_idx + 1]
real_last_index = len(turns) - 1
# Generate the conversation up to the turn, with final turn replaced with dummy content
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
dummy_ids = self.prompter.build_prompt(
turns_with_empty, tools=tools, real_last_index=real_last_index
) # type: ignore
# Generate the conversation up to the turn, with final turn included
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
full_ids = self.prompter.build_prompt(
turns_with_content, tools=tools, real_last_index=real_last_index
) # type: ignore
if not full_ids or not dummy_ids:
LOG.warning(f"Empty template generated for turn {turn_idx}")

View File

@@ -24,6 +24,10 @@ def is_opentelemetry_available():
)
def is_trackio_available():
return importlib.util.find_spec("trackio") is not None
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).
@@ -41,14 +45,27 @@ def get_pytorch_version() -> tuple[int, int, int]:
def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
"""Set up CUDA allocation config"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)
config_value = "expandable_segments:True,roundup_power2_divisions:16"
if (
torch_major == 2
and torch_minor >= 9
and os.getenv("PYTORCH_ALLOC_CONF") is None
):
os.environ["PYTORCH_ALLOC_CONF"] = config_value
elif (
torch_major == 2
and torch_minor >= 2
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value
def set_misc_env():
if os.getenv("XFORMERS_IGNORE_FLASH_VERSION_CHECK") is None:
os.environ["XFORMERS_IGNORE_FLASH_VERSION_CHECK"] = "1"
def get_not_null(value, default=None):

View File

@@ -0,0 +1,44 @@
"""Trackio module for trainer callbacks"""
from typing import TYPE_CHECKING
import trackio
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
from axolotl.utils.environment import is_package_version_ge
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
"""Callback for trackio integration"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments",
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if is_main_process():
try:
if not is_package_version_ge("trackio", "0.11.0"):
LOG.warning(
"Trackio version 0.11.0 or higher is required to save config files. "
"Please upgrade trackio: pip install --upgrade trackio"
)
return control
trackio.save(self.axolotl_config_path)
LOG.info("The Axolotl config has been saved to Trackio.")
except (FileNotFoundError, ConnectionError, AttributeError) as err:
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
return control

View File

@@ -15,6 +15,12 @@
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{#- Determine the real last index: use provided value or default to messages length - 1 #}
{%- if real_last_index is defined and real_last_index is not none %}
{%- set ns.real_last_index = real_last_index %}
{%- else %}
{%- set ns.real_last_index = messages|length - 1 %}
{%- endif %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
@@ -37,7 +43,7 @@
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}

View File

@@ -203,6 +203,7 @@ def wrap_streaming_dataset(
max_seq_length=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
multipack_attn=multipack_attn,
bin_size=cfg.sample_packing_bin_size,
)
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
@@ -254,6 +255,7 @@ def encode_packed_streaming(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
bin_size: int,
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
@@ -278,6 +280,7 @@ def encode_packed_streaming(
batch_max_len=batch_size * max_seq_length,
drop_last=True,
num_processes=1,
bin_size=bin_size,
)
chunked_data = defaultdict(list)

View File

@@ -188,7 +188,10 @@ def handle_long_seq_in_dataset(
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Filtered dataset with long sequences removed.
Filtered dataset with long sequences handled according to the excess_length_strategy value:
'drop' (default) excludes any sequence longer than sequence_len
'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
"""
if (
hasattr(dataset, "column_names")
@@ -206,10 +209,13 @@ def handle_long_seq_in_dataset(
)
return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
with contextlib.suppress(AttributeError):
@@ -228,9 +234,13 @@ def handle_long_seq_in_dataset(
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
)
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,

View File

@@ -80,6 +80,9 @@ class HFMistralTokenizer(MistralCommonTokenizer):
) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None)
try:
if add_generation_prompt:
self._set_mode(ValidationMode.serving)
@@ -218,3 +221,10 @@ class HFMistralTokenizer(MistralCommonTokenizer):
model_input_names=model_input_names,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]:
"""
Patches to remove save_jinja_files from being passed onwards.
"""
kwargs.pop("save_jinja_files", None)
return super().save_pretrained(*args, **kwargs)

View File

@@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler):
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
bin_size: int, # The max number of samples that can be packed in a single bin
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 4, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
mp_start_method: str = "fork",
@@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler):
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
bin_size=self.bin_size or self.batch_max_len,
num_processes=min(4, num_processes) if num_processes else 4,
safe_mode=self.safe_mode,
mp_start_method=self.mp_start_method,

View File

@@ -2,6 +2,7 @@
from typing import Annotated, Any, Literal
from accelerate.utils import is_fp8_available
from annotated_types import MinLen
from packaging import version
from pydantic import (
@@ -33,6 +34,7 @@ from axolotl.utils.schemas.integrations import (
MLFlowConfig,
OpenTelemetryConfig,
RayConfig,
TrackioConfig,
WandbConfig,
)
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
@@ -62,6 +64,7 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
CometConfig,
TrackioConfig,
OpenTelemetryConfig,
LISAConfig,
GradioConfig,
@@ -173,6 +176,12 @@ class AxolotlInputConfig(
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_use_liger_kernel: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."},
)
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
@@ -445,10 +454,10 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(
@@ -1092,6 +1101,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="after")
def check_fp8(self):
if self.fp8 and not self.capabilities.fp8:
raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():
raise ValueError(
"fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
)
return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):

View File

@@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel):
"description": "Port for the Prometheus metrics HTTP server"
},
)
class TrackioConfig(BaseModel):
"""Trackio configuration subset"""
use_trackio: bool | None = None
trackio_project_name: str | None = Field(
default=None,
json_schema_extra={"description": "Your trackio project name"},
)
trackio_run_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your trackio run"},
)
trackio_space_id: str | None = Field(
default=None,
json_schema_extra={
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
},
)

View File

@@ -109,6 +109,12 @@ class LoraConfig(BaseModel):
)
},
)
peft_autocast_adapter_dtype: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT."
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,

View File

@@ -751,12 +751,19 @@ class OptimizationValidationMixin:
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
if data.get("optimizer") == "muon" and (
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
)
if data.get("optimizer") == "muon":
if data.get("deepspeed"):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed"
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
if str(fsdp_version) != "2":
raise ValueError(
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
)
return data
@model_validator(mode="before")
@@ -840,40 +847,6 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if (
@@ -975,6 +948,40 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
class SystemValidationMixin:
"""Validation methods related to system and hardware configuration."""

View File

@@ -0,0 +1,17 @@
"""Module for trackio utilities"""
import os
from axolotl.utils.dict import DictDefault
def setup_trackio_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("trackio_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
cfg.use_trackio = True

View File

@@ -205,12 +205,15 @@ def add_length(sample):
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
If raise_on_drop is set, the code raises a ValueError if a sample is
encountered that is too long and would have been dropped.
"""
min_sequence_len = min_sequence_len or 2
@@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
results.append(min_sequence_len <= length <= sequence_len)
return results

View File

@@ -474,10 +474,8 @@ def rand_reward_func(prompts, completions) -> list[float]:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
from axolotl.contribs.mit.muon import MuonOptimizerFactory
from axolotl.contribs.mit.muon.muon import Muon
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
@@ -556,10 +554,8 @@ class TestHFCausalTrainerBuilder:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
from axolotl.contribs.mit.muon import MuonOptimizerFactory
from axolotl.contribs.mit.muon.muon import Muon
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory

View File

@@ -0,0 +1,168 @@
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert len(checkpoint_files) > 0, (
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestDistMuon:
"""Test class for DistMuon optimizer with FSDP2 functionality."""
@require_torch_2_7_0
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_lora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -0,0 +1,81 @@
"""
Test for KD chat template strategies
"""
from unittest.mock import Mock
import pytest
from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
class TestChatTemplateStrategyWithKDv2:
"""Test v2 strategy correctly handles target_token_ids"""
@pytest.fixture
def v2_strategy(self):
"""Create v2 strategy instance with mocked dependencies"""
# Mock prompter
mock_prompter = Mock()
mock_prompter.roles = {"user": "user", "assistant": "assistant"}
mock_prompter.chat_template_msg_variables = ["role", "content"]
mock_prompter.chat_template = "{{ messages }}"
# Mock tokenizer
mock_tokenizer = Mock()
mock_tokenizer.pad_token_id = 0
mock_tokenizer.eos_token_id = 2
mock_tokenizer.bos_token_id = 1
mock_tokenizer.eos_token = "<|endoftext|>"
mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])
mock_tokenizer.encode = Mock(return_value=[2])
return ChatTemplateStrategyWithKDv2(
prompter=mock_prompter,
tokenizer=mock_tokenizer,
train_on_inputs=False,
sequence_len=512,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
)
def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
"""
Test that v2's _prepare_kd_fields hook adds target_token_ids.
Validates the Template Method pattern fix where v2 overrides
the hook to add target_token_ids before transform.
"""
tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
original = {"target_token_ids": [[10, 20], [30, 40]]}
result = v2_strategy._prepare_kd_fields(tokenized, original)
assert "target_token_ids" in result
assert result["target_token_ids"] == [[10, 20], [30, 40]]
def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
"""Test hook handles missing target_token_ids gracefully"""
tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
original = {}
result = v2_strategy._prepare_kd_fields(tokenized, original)
assert "target_token_ids" not in result
def test_v2_transform_requires_target_token_ids(self, v2_strategy):
"""
Test v2's transform fails without target_token_ids.
Validates the bug fix - transform expects target_token_ids
to be added by the hook.
"""
sample = {
"input_ids": [1, 10, 20, 30, 2],
"labels": [1, 10, 20, 30, 2],
"logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
}
with pytest.raises(KeyError, match="target_token_ids"):
v2_strategy.transform_logprobs(sample)

View File

@@ -7,6 +7,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase):
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
def test_excess_length_strategy(self):
"""Test that excess_length_strategy results in a value error when set to 'raise'."""
# -- single sequence --
# This should work
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
drop_long_seq(data, 32, raise_on_drop=True)
# This should return True, since data fits
dropped = drop_long_seq(data, 32)
self.assertTrue(dropped)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
# This should return False, since data doesn't fit
dropped = drop_long_seq(data, 15)
self.assertFalse(dropped)
# -- batch sequence --
# This should work
data = {
"input_ids": [
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
]
}
drop_long_seq(data, 32, raise_on_drop=True)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
# This should keep the first but drop the second entry
dropped = drop_long_seq(data, 15)
self.assertEqual(dropped, [True, False])
if __name__ == "__main__":
unittest.main()

View File

@@ -13,7 +13,9 @@ from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.data.sft import (
_load_tokenized_prepared_datasets,
)
from axolotl.utils.dict import DictDefault
from tests.constants import (

View File

@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
validate_config(cfg)

View File

@@ -123,6 +123,17 @@ class TestFSDPValidation:
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
assert cfg.fsdp_config.reshard_after_forward is True
def test_muon_fsdp1_rejected(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
optimizer="muon",
fsdp_version=1,
fsdp_config={"reshard_after_forward": True},
)
with pytest.raises(
ValueError, match="Muon optimizer is only compatible with FSDP2"
):
validate_config(cfg)
@pytest.mark.parametrize(
"rl",
[