From b234532d9f02e232c973aa1cf6d137530b0b8d27 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Fri, 28 Nov 2025 18:54:48 +0700 Subject: [PATCH 01/24] Feat: add peft_ensure_weight_tying (#3278) * feat: upgrade peft to 0.18.0 * feat: add peft_ensure_weight_tying * fix: default * chore: adjust kwarg per feedback --- requirements.txt | 2 +- src/axolotl/loaders/adapter.py | 2 ++ src/axolotl/utils/schemas/peft.py | 9 +++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 08759279d..f020aaffc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ liger-kernel==0.6.3 packaging==23.2 huggingface_hub>=0.36.0 -peft>=0.17.1 +peft>=0.18.0 tokenizers>=0.22.1 transformers==4.57.1 accelerate==1.11.0 diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 8e8177b62..dca688bb2 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -102,6 +102,8 @@ def load_lora( lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication if cfg.peft_trainable_token_indices: lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices + if cfg.peft_ensure_weight_tying is not None: + lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying # Determine the correct PEFT task type model_cls = type(model).__name__ diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index af22913fd..fd16dec3f 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -100,6 +100,15 @@ class LoraConfig(BaseModel): ) }, ) + peft_ensure_weight_tying: bool | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Whether to tie adapter weights for tied model weights. " + "See https://github.com/huggingface/peft/issues/2864" + ) + }, + ) qlora_sharded_model_loading: bool | None = Field( default=False, From 7fb6a947d9411867f22de9b18ee4d42acc76398f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:03:14 +0700 Subject: [PATCH 02/24] chore: update pre-commit hooks (#3287) Co-authored-by: SalmanMohammadi <25081738+SalmanMohammadi@users.noreply.github.com> --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86d8927d2..3500bb0aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,13 +11,13 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.3 + rev: v0.14.7 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.18.2 + rev: v1.19.0 hooks: - id: mypy additional_dependencies: @@ -26,7 +26,7 @@ repos: 'pydantic>=2.5.3', ] - repo: https://github.com/PyCQA/bandit - rev: 1.8.6 + rev: 1.9.2 hooks: - id: bandit args: [ From c6ddcdd06ad352c2e2d770fc7e222b9c612d5755 Mon Sep 17 00:00:00 2001 From: Yohan Na Date: Mon, 1 Dec 2025 17:52:45 +0900 Subject: [PATCH 03/24] feat: add exaone4 chat template and update enums (#3279) * feat: add exaone4 chat template and update enums * fix: handle first message as system or tools in exaone4 chat template * Update src/axolotl/utils/chat_templates/templates/exaone4.jinja Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * fix: lint --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: NanoCode012 --- .../chat_templates/templates/exaone4.jinja | 126 ++++++++++++++++++ src/axolotl/utils/schemas/enums.py | 1 + 2 files changed, 127 insertions(+) create mode 100644 src/axolotl/utils/chat_templates/templates/exaone4.jinja diff --git a/src/axolotl/utils/chat_templates/templates/exaone4.jinja b/src/axolotl/utils/chat_templates/templates/exaone4.jinja new file mode 100644 index 000000000..8bfb0651b --- /dev/null +++ b/src/axolotl/utils/chat_templates/templates/exaone4.jinja @@ -0,0 +1,126 @@ +{%- if not skip_think is defined %} + {%- set skip_think = true %} +{%- endif %} +{%- set role_indicators = { + 'user': '[|user|]\n', + 'assistant': '[|assistant|]\n', + 'system': '[|system|]\n', + 'tool': '[|tool|]\n' +} %} +{%- set end_of_turn = '[|endofturn|]\n' %} +{%- macro available_tools(tools) %} + {{- "# Available Tools" }} + {{- "\nYou can use none, one, or multiple of the following tools by calling them as functions to help with the user’s query." }} + {{- "\nHere are the tools available to you in JSON format within and tags:\n" }} + {%- for tool in tools %} + {{- "" }} + {{- tool | tojson(ensure_ascii=False) | safe }} + {{- "\n" }} + {%- endfor %} + {{- "\nFor each function call you want to make, return a JSON object with function name and arguments within and tags, like:" }} + {{- "\n{\"name\": function_1_name, \"arguments\": {argument_1_name: argument_1_value, argument_2_name: argument_2_value}}" }} + {{- "\n{\"name\": function_2_name, \"arguments\": {...}}\n..." }} + {{- "\nNote that if no argument name is specified for a tool, you can just print the argument value directly, without the argument name or JSON formatting." }} +{%- endmacro %} +{%- set ns = namespace(last_query_index = messages|length - 1) %} +{%- for message in messages %} + {%- if message.role == "user" and message.content is string %} + {%- set ns.last_query_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{%- for i in range(messages | length) %} + {%- set msg = messages[i] %} + {%- set role = msg.role %} + {%- if role not in role_indicators %} + {{- raise_exception('Unknown role: ' ~ role) }} + {%- endif %} + {# ---- Case A: If the first message is "system", handle it here alone (without continue) ---- #} + {%- if i == 0 and role == 'system' %} + {{- role_indicators['system'] }} + {{- msg.content }} + {%- if tools is defined and tools %} + {{- "\n\n" }}{{- available_tools(tools) }} + {%- endif %} + {{- end_of_turn -}} + {%- else %} + {# ---- Case B: If the first message is tools instead of system, inject the system tools preamble ---- #} + {%- if i == 0 and tools is defined and tools %} + {{- role_indicators['system'] }} + {{- available_tools(tools) }} + {{- end_of_turn -}} + {%- endif %} + {%- endif %} + {%- if role == 'assistant' %} + {{- role_indicators['assistant'] }} + {%- if msg.content %} + {%- if "" in msg.content %} + {%- set content = msg.content.split('')[-1].strip() %} + {%- set reasoning_content = msg.content.split('')[0].strip() %} + {%- if reasoning_content.startswith("") %} + {%- set reasoning_content = reasoning_content[7:].strip() %} + {%- endif %} + {%- else %} + {%- set content = msg.content %} + {%- endif %} + {%- if msg.reasoning_content %} + {%- set reasoning_content = msg.reasoning_content %} + {%- endif %} + {%- if (not skip_think and loop.last) and reasoning_content is defined %} + {{- "\n" }} + {{- reasoning_content}} + {{- "\n\n\n" }} + {%- else %} + {{- "\n\n\n\n" }} + {%- endif %} + {{- content }} + {%- endif %} + {%- if msg.tool_calls %} + {%- if msg.content %} + {{- "\n" }} + {%- else %} + {{- "\n\n\n\n" }} + {%- endif %} + {%- for tool_call in msg.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- set arguments = tool_call.arguments %} + {%- elif tool_call.parameters is defined %} + {%- set arguments = tool_call.parameters %} + {%- else %} + {{- raise_exception('arguments or parameters are mandatory: ' ~ tool_call) }} + {%- endif %} + {{- "" }}{"name": "{{- tool_call.name }}", "arguments": {{ arguments | tojson(ensure_ascii=False) | safe }}}{{- "" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- end_of_turn -}} + {%- elif role == "tool" %} + {%- if i == 0 or messages[i - 1].role != "tool" %} + {{- role_indicators['tool'] }} + {%- endif %} + {%- if msg.content is defined %} + {{- "" }}{"result": {{ msg.content | tojson(ensure_ascii=False) | safe }}}{{- "" }} + {%- endif %} + {%- if loop.last or messages[i + 1].role != "tool" %} + {{- end_of_turn -}} + {%- else %} + {{- "\n" }} + {%- endif %} + {%- else %} + {{- role_indicators[role] }} + {{- msg.content }} + {{- end_of_turn -}} + {%- endif %} +{% endfor %} +{%- if add_generation_prompt %} + {{- role_indicators['assistant'] }} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- "\n" }} + {%- else %} + {{- "\n\n\n\n" }} + {%- endif %} +{%- endif %} diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index bcd03e1a2..f86d1a191 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -58,6 +58,7 @@ class ChatTemplate(str, Enum): falcon_h1 = "falcon_h1" tokenizer_default = "tokenizer_default" exaone = "exaone" + exaone4 = "exaone4" metharme = "metharme" pixtral = "pixtral" llava = "llava" From 4a0f98e612b0b0b7c70660f5b32b9560c1956a40 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 2 Dec 2025 21:16:23 +0700 Subject: [PATCH 04/24] feat: upgrade liger to 0.6.4 (#3289) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f020aaffc..21c94a3c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ bitsandbytes==0.48.2 triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 -liger-kernel==0.6.3 +liger-kernel==0.6.4 # END section packaging==23.2 From 86d8cca1494ba4e1872c846aa2afd83cedbfd618 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 3 Dec 2025 01:12:55 +0700 Subject: [PATCH 05/24] Feat: add trinity by ArceeAI (#3292) --- examples/trinity/README.md | 38 +++++++++++ .../trinity/trinity-nano-preview-qlora.yaml | 67 +++++++++++++++++++ src/axolotl/common/architectures.py | 1 + src/axolotl/monkeypatch/multipack.py | 1 + 4 files changed, 107 insertions(+) create mode 100644 examples/trinity/README.md create mode 100644 examples/trinity/trinity-nano-preview-qlora.yaml diff --git a/examples/trinity/README.md b/examples/trinity/README.md new file mode 100644 index 000000000..28b2e2b52 --- /dev/null +++ b/examples/trinity/README.md @@ -0,0 +1,38 @@ +# Finetune ArceeAI's Trinity with Axolotl + +[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build). + +2. Run the finetuning example: + + ```bash + axolotl train examples/trinity/trinity-nano-preview-qlora.yaml + ``` + +This config uses about 24.9 GiB VRAM. + +Let us know how it goes. Happy finetuning! πŸš€ + +### TIPS + +- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Related Resources + +- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/trinity/trinity-nano-preview-qlora.yaml b/examples/trinity/trinity-nano-preview-qlora.yaml new file mode 100644 index 000000000..43263cabd --- /dev/null +++ b/examples/trinity/trinity-nano-preview-qlora.yaml @@ -0,0 +1,67 @@ +base_model: arcee-ai/Trinity-Nano-Preview +trust_remote_code: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# CCE - N/A as of now +# plugins: +# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +# flash_attention: true # Not supported +sdp_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index c8a2f0836..f4d6ca928 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -17,4 +17,5 @@ MOE_ARCH_BLOCK = { "deepseek_v3": "DeepseekV3MoE", "gpt_oss": "GptOssDecoderLayer", "lfm2_moe": "Lfm2MoeSparseMoeBlock", + "afmoe": "AfmoeMoE", } diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index fdda3c3bc..9642b1edb 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -52,6 +52,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "olmo", "olmo2", "olmo3", + "afmoe", ] From 2b66ee189c19a659e3aea24473388fd98b522b47 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 4 Dec 2025 20:32:08 +0700 Subject: [PATCH 06/24] Feat: add ministral3 (#3297) * feat: add ministral and mistral3 * chore: lint * feat: update cce for ministral * fix: add vram usage * feat: update for release * fix: save_pretrained issue in v5 * fix: add instructions to use v5 branch * fix: add to multipack * fix: improve instructions * fix: add model to readme --- README.md | 2 +- .../colab-axolotl-example.ipynb | 2 +- examples/magistral/README.md | 2 +- examples/ministral/README.md | 58 +++++++++++ examples/ministral/ministral-small-qlora.yaml | 67 +++++++++++++ examples/ministral/think/README.md | 99 +++++++++++++++++++ .../think/ministral3-small-think-qlora.yaml | 67 +++++++++++++ examples/olmo3/README.md | 20 ++-- scripts/cutcrossentropy_install.py | 2 +- .../integrations/cut_cross_entropy/README.md | 4 +- .../cut_cross_entropy/__init__.py | 2 +- src/axolotl/monkeypatch/multipack.py | 2 + .../utils/mistral/mistral_tokenizer.py | 7 ++ 13 files changed, 314 insertions(+), 20 deletions(-) create mode 100644 examples/ministral/README.md create mode 100644 examples/ministral/ministral-small-qlora.yaml create mode 100644 examples/ministral/think/README.md create mode 100644 examples/ministral/think/ministral3-small-think-qlora.yaml diff --git a/README.md b/README.md index 1517fb874..13518f2a8 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ ## πŸŽ‰ Latest Updates -- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3). +- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral). - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 57a638948..06705eb3d 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -40,7 +40,7 @@ "%%capture\n", "# This step can take ~5-10 minutes to install dependencies\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", - "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\"" + "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\"" ] }, { diff --git a/examples/magistral/README.md b/examples/magistral/README.md index a09138744..40a793f10 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these Here is an example of how to install from pip: ```bash -# Ensure you have Pytorch installed (Pytorch 2.6.0 min) +# Ensure you have Pytorch installed (Pytorch 2.7.0 min) pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' ``` diff --git a/examples/ministral/README.md b/examples/ministral/README.md new file mode 100644 index 000000000..b088c06ec --- /dev/null +++ b/examples/ministral/README.md @@ -0,0 +1,58 @@ +# Finetune Ministral with Axolotl + +Ministral is a family of openweight models from MistralAI found on HuggingFace at [2410](mistralai/Ministral-8B-Instruct-2410) and [2512](https://huggingface.co/collections/mistralai/ministral-3) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Run the finetuning example: + + ```bash + axolotl train examples/ministral/ministral-small-qlora.yaml + ``` + +This config uses about 8.76 GiB VRAM. + +Let us know how it goes. Happy finetuning! πŸš€ + +### Thinking + +MistralAI has released their [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps. + +πŸ“š **[See the Thinking fine-tuning guide β†’](./think/README.md)** + +For Ministral3 Base/Instruct, you can reuse the above config to train supervised finetuning. + +### Tips + +- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/ministral/ministral-small-qlora.yaml b/examples/ministral/ministral-small-qlora.yaml new file mode 100644 index 000000000..0d5300ef6 --- /dev/null +++ b/examples/ministral/ministral-small-qlora.yaml @@ -0,0 +1,67 @@ +base_model: mistralai/Ministral-8B-Instruct-2410 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/ministral/think/README.md b/examples/ministral/think/README.md new file mode 100644 index 000000000..0ee5ea876 --- /dev/null +++ b/examples/ministral/think/README.md @@ -0,0 +1,99 @@ +# Ministral3 2512 Thinking Fine-tuning + +This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections. + +Thanks to the team at MistralAI for giving us early access to prepare for these releases. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl (see [main README](../README.md)) + +## Getting Started + +1. Install transformers v5 + + ```bash + pip install transformers==5.0.0rc0 + ``` + + Note: This is still experimental in Axolotl. Other stuff may break. + +2. Upgrade `mistral-common` + + ```bash + pip install mistral-common==1.8.6 + ``` + +3. Swap to the Axolotl transformers v5 branch + + ```bash + # copy examples/ministral/think/ministral3-small-think-qlora.yaml somewhere + cp examples/ministral/think/ministral3-small-think-qlora.yaml ministral3-small-think-qlora.yaml + + git fetch + git checkout transformers-v5 + ``` + +4. Run the thinking model fine-tuning: + + ```bash + axolotl train ministral3-small-think-qlora.yaml + ``` + +This config uses about 4.76 GiB VRAM. + +### Tips + +- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below. +- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent. + +## Dataset Format + +The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages. + +Example format: + +```json +{ + "messages": [ + { + "role": "system", + "content": [ + { "type": "text", "text": "{SYSTEM_PROMPT}"} + ] + }, + { + "role": "user", + "content": [ + { "type": "text", "text": "Solve this step by step: What is 15% of 240?"} + ] + }, + { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 Γ— 240 = 36." + }, + { + "type": "text", + "text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 Γ— 0.15 = 36\n\nTherefore, 15% of 240 is 36." + } + ] + } + ] +} +``` + +### Advanced Options + +The `thinking` section supports an optional `closed` parameter: + +```json +{ + "type": "thinking", + "thinking": "Internal reasoning here...", + "closed": true // Default: true, controls adding the closing [/THINK] tag +} +``` diff --git a/examples/ministral/think/ministral3-small-think-qlora.yaml b/examples/ministral/think/ministral3-small-think-qlora.yaml new file mode 100644 index 000000000..987c0bd54 --- /dev/null +++ b/examples/ministral/think/ministral3-small-think-qlora.yaml @@ -0,0 +1,67 @@ +base_model: mistralai/Ministral-3-3B-Reasoning-2512 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: Nanobit/text-think-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/olmo3/README.md b/examples/olmo3/README.md index d4dbe05a9..2f98eb73e 100644 --- a/examples/olmo3/README.md +++ b/examples/olmo3/README.md @@ -6,24 +6,16 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations ## Getting started -1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Run the finetuning example: - Here is an example of how to install from pip: ```bash - # Ensure you have a compatible version of Pytorch installed - pip3 install packaging setuptools wheel ninja - pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0' - - # Install Cut Cross Entropy - python scripts/cutcrossentropy_install.py | sh + axolotl train examples/olmo3/olmo3-7b-qlora.yaml ``` -2. Run the finetuning example: - -```bash -axolotl train examples/olmo3/olmo3-7b-qlora.yaml -``` - Let us know how it goes. Happy finetuning! πŸš€ ### TIPS diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index 91d0f45d6..ec5c6d475 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 4f98ac089..2c5b0f6e5 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88" ``` ## Usage @@ -61,6 +61,8 @@ plugins: - llama4 - llama4_text - llava +- ministral +- ministral3 - mistral - mistral3 - mixtral diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index b8f7e9da3..98a1659b1 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`' ) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 9642b1edb..6a6b935be 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -52,6 +52,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "olmo", "olmo2", "olmo3", + "ministral", + "ministral3", "afmoe", ] diff --git a/src/axolotl/utils/mistral/mistral_tokenizer.py b/src/axolotl/utils/mistral/mistral_tokenizer.py index 0414ece78..af174cdac 100644 --- a/src/axolotl/utils/mistral/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral/mistral_tokenizer.py @@ -218,3 +218,10 @@ class HFMistralTokenizer(MistralCommonTokenizer): model_input_names=model_input_names, clean_up_tokenization_spaces=clean_up_tokenization_spaces, ) + + def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]: + """ + Patches to remove save_jinja_files from being passed onwards. + """ + kwargs.pop("save_jinja_files", None) + return super().save_pretrained(*args, **kwargs) From 5992e607a2e59dced8b0ccb520527b1ad57c94f7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 4 Dec 2025 21:44:44 +0700 Subject: [PATCH 07/24] fix: improve ministral3 docs to be clearer (#3300) * fix: improve ministral3 docs to be clearer * fix: title * chore: wording --- examples/ministral/README.md | 10 +-- examples/ministral3/README.md | 79 +++++++++++++++++++ examples/ministral3/ministral3-3b-qlora.yaml | 67 ++++++++++++++++ .../{ministral => ministral3}/think/README.md | 34 +------- .../think/ministral3-3b-think-qlora.yaml} | 0 examples/ministral3/vision/README.md | 57 +++++++++++++ .../vision/ministral3-3b-vision-qlora.yml | 64 +++++++++++++++ 7 files changed, 272 insertions(+), 39 deletions(-) create mode 100644 examples/ministral3/README.md create mode 100644 examples/ministral3/ministral3-3b-qlora.yaml rename examples/{ministral => ministral3}/think/README.md (72%) rename examples/{ministral/think/ministral3-small-think-qlora.yaml => ministral3/think/ministral3-3b-think-qlora.yaml} (100%) create mode 100644 examples/ministral3/vision/README.md create mode 100644 examples/ministral3/vision/ministral3-3b-vision-qlora.yml diff --git a/examples/ministral/README.md b/examples/ministral/README.md index b088c06ec..f8af7bf27 100644 --- a/examples/ministral/README.md +++ b/examples/ministral/README.md @@ -1,6 +1,6 @@ # Finetune Ministral with Axolotl -Ministral is a family of openweight models from MistralAI found on HuggingFace at [2410](mistralai/Ministral-8B-Instruct-2410) and [2512](https://huggingface.co/collections/mistralai/ministral-3) (see [Thinking](#thinking)). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. +Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. ## Getting started @@ -18,14 +18,6 @@ This config uses about 8.76 GiB VRAM. Let us know how it goes. Happy finetuning! πŸš€ -### Thinking - -MistralAI has released their [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps. - -πŸ“š **[See the Thinking fine-tuning guide β†’](./think/README.md)** - -For Ministral3 Base/Instruct, you can reuse the above config to train supervised finetuning. - ### Tips - We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`. diff --git a/examples/ministral3/README.md b/examples/ministral3/README.md new file mode 100644 index 000000000..6ed7efda5 --- /dev/null +++ b/examples/ministral3/README.md @@ -0,0 +1,79 @@ +# Finetune Ministral3 with Axolotl + +Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning. + +Thanks to the team at MistralAI for giving us early access to prepare for these releases. + +Note: This is still experimental given it is based on transformers v5 RC. + +## Getting started + +1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Swap to the Axolotl transformers v5 branch + + ```bash + cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml + + git fetch + git checkout transformers-v5 + + # Install packages for transformers v5 + pip install -e . + ``` + +4. Run the fine-tuning: + + ```bash + axolotl train ministral3-3b-qlora.yaml + ``` + +Let us know how it goes. Happy finetuning! πŸš€ + + +### Tips + +- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +### Thinking + +Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps. + +πŸ“š **[See the Thinking fine-tuning guide β†’](./think/README.md)** + +### Vision + +Ministral3 2512 model also supports vision capabilities. + +πŸ“š **[See the Vision fine-tuning guide β†’](./vision/README.md)** + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/ministral3/ministral3-3b-qlora.yaml b/examples/ministral3/ministral3-3b-qlora.yaml new file mode 100644 index 000000000..a31545ab2 --- /dev/null +++ b/examples/ministral3/ministral3-3b-qlora.yaml @@ -0,0 +1,67 @@ +base_model: mistralai/Ministral-3-3B-Reasoning-2512 + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/ministral/think/README.md b/examples/ministral3/think/README.md similarity index 72% rename from examples/ministral/think/README.md rename to examples/ministral3/think/README.md index 0ee5ea876..8c40adbb9 100644 --- a/examples/ministral/think/README.md +++ b/examples/ministral3/think/README.md @@ -2,8 +2,6 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections. -Thanks to the team at MistralAI for giving us early access to prepare for these releases. - ## Prerequisites Before starting, ensure you have: @@ -11,35 +9,11 @@ Before starting, ensure you have: ## Getting Started -1. Install transformers v5 +Run the thinking model fine-tuning: - ```bash - pip install transformers==5.0.0rc0 - ``` - - Note: This is still experimental in Axolotl. Other stuff may break. - -2. Upgrade `mistral-common` - - ```bash - pip install mistral-common==1.8.6 - ``` - -3. Swap to the Axolotl transformers v5 branch - - ```bash - # copy examples/ministral/think/ministral3-small-think-qlora.yaml somewhere - cp examples/ministral/think/ministral3-small-think-qlora.yaml ministral3-small-think-qlora.yaml - - git fetch - git checkout transformers-v5 - ``` - -4. Run the thinking model fine-tuning: - - ```bash - axolotl train ministral3-small-think-qlora.yaml - ``` +```bash +axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml +``` This config uses about 4.76 GiB VRAM. diff --git a/examples/ministral/think/ministral3-small-think-qlora.yaml b/examples/ministral3/think/ministral3-3b-think-qlora.yaml similarity index 100% rename from examples/ministral/think/ministral3-small-think-qlora.yaml rename to examples/ministral3/think/ministral3-3b-think-qlora.yaml diff --git a/examples/ministral3/vision/README.md b/examples/ministral3/vision/README.md new file mode 100644 index 000000000..369b0116a --- /dev/null +++ b/examples/ministral3/vision/README.md @@ -0,0 +1,57 @@ +# Ministral3 2512 Vision Fine-tuning + +This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl. + +## Prerequisites + +Before starting, ensure you have: +- Installed Axolotl from source (see [main README](../README.md#getting-started)) + +## Getting started + +1. Install the required vision lib: + ```bash + pip install 'mistral-common[opencv]==1.8.6' + ``` + +2. Download the example dataset image: + ```bash + wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg + ``` + +3. Run the fine-tuning: + ```bash + axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml + ``` + +WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look. + +### Tips + +Key differences from text-only model: +- Multi-modal dataset format required +- Sample packing not supported + +## Dataset Format + +The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format). + +One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now. + +Example: +```json +{ + "messages": [ + {"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]}, + {"role": "user", "content": [ + { "type": "text", "text": "What's in this image?"}, + {"type": "image", "path": "path/to/image.jpg" } + ]}, + {"role": "assistant", "content": [{ "type": "text", "text": "..." }]}, + ], +} +``` + +## Limitations + +- Sample Packing is not supported for multi-modality training currently. diff --git a/examples/ministral3/vision/ministral3-3b-vision-qlora.yml b/examples/ministral3/vision/ministral3-3b-vision-qlora.yml new file mode 100644 index 000000000..0a0fdce4a --- /dev/null +++ b/examples/ministral3/vision/ministral3-3b-vision-qlora.yml @@ -0,0 +1,64 @@ +base_model: mistralai/Ministral-3-3B-Reasoning-2512 +processor_type: AutoProcessor + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_4bit: true + +# these 3 lines are needed for now to handle vision chat templates w images +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +# sample dataset below requires downloading image in advance +# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg +datasets: + - path: Nanobit/text-vision-2k-test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./outputs/out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +fp16: +tf32: true + +gradient_checkpointing: true +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config From 75b20fb66f4b37e01217e01dc6fb7ef40ff5227f Mon Sep 17 00:00:00 2001 From: salman Date: Sat, 6 Dec 2025 16:27:18 +0000 Subject: [PATCH 08/24] Save processor in quantizer CLI (#3290) --- src/axolotl/cli/quantize.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index c11bcc6d9..f4fcc6d7d 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -8,7 +8,7 @@ from typing import Union from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig from axolotl.cli.config import load_cfg -from axolotl.loaders import load_tokenizer +from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.logging import get_logger from axolotl.utils.quantization import ( TorchAOQuantDType, @@ -66,6 +66,11 @@ def do_quantize( LOG.info(f"Loading model from {model_path}.") tokenizer = load_tokenizer(cfg) + + processor = None + if cfg.is_multimodal: + processor = load_processor(cfg, tokenizer) + config = AutoConfig.from_pretrained(model_path) torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None model = AutoModelForCausalLM.from_pretrained( @@ -107,6 +112,10 @@ def do_quantize( save_jinja_files=cfg.tokenizer_save_jinja_files, ) + if processor: + LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.") + processor.save_pretrained(str(Path(output_dir) / "quantized")) + if hub_model_id: hub_model_id = ( hub_model_id.rstrip("-") @@ -114,6 +123,8 @@ def do_quantize( ) model.push_to_hub(hub_model_id, safe_serialization=False) tokenizer.push_to_hub(hub_model_id) + if processor: + processor.push_to_hub(hub_model_id) LOG.info(f"Quantized model pushed to: {hub_model_id}.") LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.") From b3f4aa149fa1d2a812c728b777e87822420ecde5 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 8 Dec 2025 19:46:18 +0530 Subject: [PATCH 09/24] fix bin size (#3307) * fix bin size * lint --------- Co-authored-by: Ved --- src/axolotl/utils/data/streaming.py | 3 +++ src/axolotl/utils/samplers/multipack.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py index 2cb35ee7c..8b6b8a439 100644 --- a/src/axolotl/utils/data/streaming.py +++ b/src/axolotl/utils/data/streaming.py @@ -203,6 +203,7 @@ def wrap_streaming_dataset( max_seq_length=cfg.sequence_len, batch_size=cfg.micro_batch_size, multipack_attn=multipack_attn, + bin_size=cfg.sample_packing_bin_size, ) # Set this to 1 so downstream data_loader doesn't try to increase the batch size @@ -254,6 +255,7 @@ def encode_packed_streaming( collate_fn, ds_wrapper: Callable, examples: Dict[str, List], + bin_size: int, max_seq_length: int = 2048, batch_size: int = 4, multipack_attn: Optional[bool] = True, @@ -278,6 +280,7 @@ def encode_packed_streaming( batch_max_len=batch_size * max_seq_length, drop_last=True, num_processes=1, + bin_size=bin_size, ) chunked_data = defaultdict(list) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 662c63caa..436a49c79 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -260,12 +260,12 @@ class MultipackBatchSampler(BatchSampler): batch_size: int, # Number of bins per batch batch_max_len: int, # Maximum sequence length (bin capacity) lengths: np.ndarray, # Sequence lengths + bin_size: int, # The max number of samples that can be packed in a single bin packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate drop_last: bool = True, # Whether to drop final batches (might be incomplete) num_count_samples: int = 4, # Number of times to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing - bin_size: int = 200, # The max number of samples that can be packed in a single bin num_processes: int | None = None, # Number of processes for parallel packing safe_mode: bool = True, # Conservative packing to prevent training instability mp_start_method: str = "fork", @@ -343,7 +343,7 @@ class MultipackBatchSampler(BatchSampler): lengths, bin_capacity=self.batch_max_len, group_size=self.group_size, - bin_size=self.bin_size, + bin_size=self.bin_size or self.batch_max_len, num_processes=min(4, num_processes) if num_processes else 4, safe_mode=self.safe_mode, mp_start_method=self.mp_start_method, From 4ac78aa562aacd9f3b568c5473180e5474f2354e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 9 Dec 2025 14:31:03 +0700 Subject: [PATCH 10/24] fix: update qwen3 jinja tokenization off a few tokens (#3295) * fix: update qwen3 jinja tokenization off a few tokens * fix: add note on tokenization issue * fix: pop last index for mistral tokenizer --- examples/qwen3/README.md | 46 +++++++++++++++++++ .../prompt_strategies/chat_template.py | 14 +++++- .../chat_templates/templates/qwen3.jinja | 8 +++- .../utils/mistral/mistral_tokenizer.py | 3 ++ 4 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 examples/qwen3/README.md diff --git a/examples/qwen3/README.md b/examples/qwen3/README.md new file mode 100644 index 000000000..a3d35881d --- /dev/null +++ b/examples/qwen3/README.md @@ -0,0 +1,46 @@ +# Finetune Qwen3 with Axolotl + +[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Run the finetuning example: + + ```bash + axolotl train examples/qwen3/32b-qlora.yaml + ``` + +Let us know how it goes. Happy finetuning! πŸš€ + +### Chat template masking a few tokens off + +If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml. + +```yaml +chat_template: qwen3 +``` + +### TIPS + +- For inference, please check the official model card as it depends on your reasoning mode. +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). + +## Related Resources + +- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 28155810f..0fec64d81 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -95,6 +95,7 @@ class ChatTemplatePrompter(Prompter): add_generation_prompt=False, images=None, tools=None, + real_last_index=None, ): """ Build a prompt from a conversation. @@ -114,6 +115,9 @@ class ChatTemplatePrompter(Prompter): if tools: chat_template_kwargs["tools"] = tools + if real_last_index: + chat_template_kwargs["real_last_index"] = real_last_index + if self.processor: if not callable(self.processor): raise TypeError("Processor must be callable") @@ -631,11 +635,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turns_with_empty = turns[:turn_idx] + [empty_turn] turns_with_content = turns[: turn_idx + 1] + real_last_index = len(turns) - 1 + # Generate the conversation up to the turn, with final turn replaced with dummy content - dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore + dummy_ids = self.prompter.build_prompt( + turns_with_empty, tools=tools, real_last_index=real_last_index + ) # type: ignore # Generate the conversation up to the turn, with final turn included - full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore + full_ids = self.prompter.build_prompt( + turns_with_content, tools=tools, real_last_index=real_last_index + ) # type: ignore if not full_ids or not dummy_ids: LOG.warning(f"Empty template generated for turn {turn_idx}") diff --git a/src/axolotl/utils/chat_templates/templates/qwen3.jinja b/src/axolotl/utils/chat_templates/templates/qwen3.jinja index 09b82ed03..77ea906e7 100644 --- a/src/axolotl/utils/chat_templates/templates/qwen3.jinja +++ b/src/axolotl/utils/chat_templates/templates/qwen3.jinja @@ -15,6 +15,12 @@ {%- endif %} {%- endif %} {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{#- Determine the real last index: use provided value or default to messages length - 1 #} +{%- if real_last_index is defined and real_last_index is not none %} + {%- set ns.real_last_index = real_last_index %} +{%- else %} + {%- set ns.real_last_index = messages|length - 1 %} +{%- endif %} {%- for message in messages[::-1] %} {%- set index = (messages|length - 1) - loop.index0 %} {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} @@ -37,7 +43,7 @@ {%- endif %} {%- endif %} {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and reasoning_content) %} + {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %} {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} diff --git a/src/axolotl/utils/mistral/mistral_tokenizer.py b/src/axolotl/utils/mistral/mistral_tokenizer.py index af174cdac..3ce6be780 100644 --- a/src/axolotl/utils/mistral/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral/mistral_tokenizer.py @@ -80,6 +80,9 @@ class HFMistralTokenizer(MistralCommonTokenizer): ) -> str | list[int]: """Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg""" + # pop unnecessary kwarg for mistral + kwargs.pop("real_last_index", None) + try: if add_generation_prompt: self._set_mode(ValidationMode.serving) From 2a664dc8ad178d2881bf361e33559f19f30be829 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 11 Dec 2025 11:56:40 -0500 Subject: [PATCH 11/24] support for xformers wheels for torch 2.9 (#3308) * support for xformers wheels for torch 2.9 * fix hf cache? * don't use hf cache from s3 * show disk free space in ci --- .github/workflows/tests.yml | 30 +++++++++++++++++------------- .runpod/Dockerfile | 1 + setup.py | 1 - src/axolotl/cli/main.py | 3 ++- src/axolotl/utils/__init__.py | 5 +++++ 5 files changed, 25 insertions(+), 15 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 95370ca3d..1cbfc15e1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -66,12 +66,12 @@ jobs: - name: Check out repository code uses: actions/checkout@v4 - - name: Restore Cache from S3 - id: hf-cache-restore-s3 - run: | - mkdir -p /home/runner/.cache/huggingface/hub - curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd - +# - name: Restore Cache from S3 +# id: hf-cache-restore-s3 +# run: | +# mkdir -p ~/.cache/huggingface/hub +# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd +# - name: Setup Python uses: actions/setup-python@v5 with: @@ -113,9 +113,13 @@ jobs: - name: Run tests run: | + df -h pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml + df -h pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml + df -h pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml + df -h pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml - name: Upload coverage to Codecov @@ -145,12 +149,12 @@ jobs: - name: Check out repository code uses: actions/checkout@v4 - - name: Restore Cache from S3 - id: hf-cache-restore-s3 - run: | - mkdir -p /home/runner/.cache/huggingface/hub - curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd - +# - name: Restore Cache from S3 +# id: hf-cache-restore-s3 +# run: | +# mkdir -p ~/.cache/huggingface/hub +# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd +# - name: Setup Python uses: actions/setup-python@v5 with: @@ -188,7 +192,7 @@ jobs: axolotl --help - name: Show HF cache - run: huggingface-cli scan-cache + run: hf cache scan - name: Run tests run: | diff --git a/.runpod/Dockerfile b/.runpod/Dockerfile index 107caf5f3..948d3f78e 100644 --- a/.runpod/Dockerfile +++ b/.runpod/Dockerfile @@ -10,6 +10,7 @@ ARG BASE_VOLUME="/runpod-volume" ENV BASE_VOLUME=$BASE_VOLUME ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets" ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub" +ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub" ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub" COPY .runpod/src /src diff --git a/setup.py b/setup.py index a1bdd6bdf..e22df40c8 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,6 @@ def parse_requirements(extras_require_map): extras_require_map.pop("fbgemm-gpu") extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"] extras_require_map["vllm"] = ["vllm==0.11.1"] - _install_requires.pop(_install_requires.index(xformers_version)) elif (major, minor) >= (2, 8): extras_require_map.pop("fbgemm-gpu") extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"] diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index dc6cca489..c0ac32050 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -26,7 +26,7 @@ from axolotl.cli.utils import ( launch_training, ) from axolotl.integrations.lm_eval.cli import lm_eval -from axolotl.utils import set_pytorch_cuda_alloc_conf +from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig @@ -45,6 +45,7 @@ def cli(): print_axolotl_text_art() load_dotenv() set_pytorch_cuda_alloc_conf() + set_misc_env() @cli.command() diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 72f8173f3..de67aadd0 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -51,6 +51,11 @@ def set_pytorch_cuda_alloc_conf(): ) +def set_misc_env(): + if os.getenv("XFORMERS_IGNORE_FLASH_VERSION_CHECK") is None: + os.environ["XFORMERS_IGNORE_FLASH_VERSION_CHECK"] = "1" + + def get_not_null(value, default=None): """ return the value if it's not None, otherwise return the default value From a1d07f42e41d819accefeff9969e0a7cbf26e52d Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 17 Dec 2025 21:12:18 +0700 Subject: [PATCH 12/24] Fix(misc): address PYTORCH_CUDA_ALLOC_CONF deprecate (#3313) * fix: leftover ministral docs changes * fix: pytorch_cuda_alloc_conf deprecation * fix: set old PYTORCH_CUDA_ALLOC_CONF env too * handle 2.9 separately --------- Co-authored-by: Wing Lian --- README.md | 2 +- .../colab-axolotl-example.ipynb | 1 - requirements.txt | 2 +- src/axolotl/utils/__init__.py | 20 +++++++++++++------ 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 13518f2a8..285867215 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ ## πŸŽ‰ Latest Updates -- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral). +- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3). - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss). - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion). - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107). diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 06705eb3d..77a4154e2 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -253,7 +253,6 @@ "source": [ "from axolotl.utils import set_pytorch_cuda_alloc_conf\n", "\n", - "# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n", "set_pytorch_cuda_alloc_conf()" ] }, diff --git a/requirements.txt b/requirements.txt index 21c94a3c2..0989325ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.5 # telemetry posthog==6.7.11 -mistral-common==1.8.5 +mistral-common==1.8.6 diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index de67aadd0..335049158 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -41,14 +41,22 @@ def get_pytorch_version() -> tuple[int, int, int]: def set_pytorch_cuda_alloc_conf(): - """Set up CUDA allocation config if using PyTorch >= 2.2""" + """Set up CUDA allocation config""" torch_version = torch.__version__.split(".") torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) - if torch_major == 2 and torch_minor >= 2: - if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( - "expandable_segments:True,roundup_power2_divisions:16" - ) + config_value = "expandable_segments:True,roundup_power2_divisions:16" + if ( + torch_major == 2 + and torch_minor >= 9 + and os.getenv("PYTORCH_ALLOC_CONF") is None + ): + os.environ["PYTORCH_ALLOC_CONF"] = config_value + elif ( + torch_major == 2 + and torch_minor >= 2 + and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None + ): + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value def set_misc_env(): From 83d4d97dccd4acf98a77f7c4c32d4a6f32a1a064 Mon Sep 17 00:00:00 2001 From: salman Date: Wed, 17 Dec 2025 15:35:22 +0100 Subject: [PATCH 13/24] Add QAT NVFP4 configs for blogpost (#3280) [skip ci] * add configs for blogpost * fix configs * fixing baseline configs --- examples/qat_nvfp4/Gemma3-12B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Gemma3-12B_qat.yml | 72 ++++++++++++++++++ .../qat_nvfp4/Math-Gemma3-12B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Math-Gemma3-12B_qat.yml | 72 ++++++++++++++++++ .../qat_nvfp4/Math-Gemma3-27B_baseline.yml | 68 +++++++++++++++++ examples/qat_nvfp4/Math-Gemma3-27B_qat.yml | 73 +++++++++++++++++++ .../qat_nvfp4/Math-Qwen2.5-72B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml | 72 ++++++++++++++++++ examples/qat_nvfp4/Qwen2.5-72B_baseline.yml | 67 +++++++++++++++++ examples/qat_nvfp4/Qwen2.5-72B_qat.yml | 72 ++++++++++++++++++ 10 files changed, 697 insertions(+) create mode 100644 examples/qat_nvfp4/Gemma3-12B_baseline.yml create mode 100644 examples/qat_nvfp4/Gemma3-12B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-12B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Gemma3-27B_qat.yml create mode 100644 examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml create mode 100644 examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml create mode 100644 examples/qat_nvfp4/Qwen2.5-72B_baseline.yml create mode 100644 examples/qat_nvfp4/Qwen2.5-72B_qat.yml diff --git a/examples/qat_nvfp4/Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Gemma3-12B_baseline.yml new file mode 100644 index 000000000..be4e86635 --- /dev/null +++ b/examples/qat_nvfp4/Gemma3-12B_baseline.yml @@ -0,0 +1,67 @@ +base_model: google/gemma-3-12b-it +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/out_gemma/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 4e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Gemma3-12B_qat.yml b/examples/qat_nvfp4/Gemma3-12B_qat.yml new file mode 100644 index 000000000..7fa81163f --- /dev/null +++ b/examples/qat_nvfp4/Gemma3-12B_qat.yml @@ -0,0 +1,72 @@ +base_model: google/gemma-3-12b-it +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/qat_out_gemma/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 4e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml new file mode 100644 index 000000000..9f209515b --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml @@ -0,0 +1,67 @@ +base_model: google/gemma-3-12b-it +# Math finetuning configuration for Gemma3-12B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_gemma/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 3e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml new file mode 100644 index 000000000..ef7e754be --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml @@ -0,0 +1,72 @@ +base_model: google/gemma-3-12b-it +# Math finetuning configuration for Gemma3-12B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_gemma/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 3e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml new file mode 100644 index 000000000..3a262d342 --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml @@ -0,0 +1,68 @@ +base_model: google/gemma-3-27b-it +# Math finetuning configuration for Gemma3-27B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_gemma27/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml new file mode 100644 index 000000000..87016ae9c --- /dev/null +++ b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml @@ -0,0 +1,73 @@ +base_model: google/gemma-3-27b-it +# Math finetuning configuration for Gemma3-27B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: gemma3 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_gemma27/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Gemma3DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml new file mode 100644 index 000000000..efec25c54 --- /dev/null +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml @@ -0,0 +1,67 @@ +base_model: Qwen/Qwen2.5-72B +# Math finetuning configuration for Qwen2.5-72B (non-instruct) +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/out_math_72b/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml new file mode 100644 index 000000000..427d7af52 --- /dev/null +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml @@ -0,0 +1,72 @@ +base_model: Qwen/Qwen2.5-72B +# Math finetuning configuration for Qwen2.5-72B (non-instruct) +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: AI-MO/NuminaMath-CoT + type: chat_template + +output_dir: ./outputs/qat_out_math_72b/ + +sequence_len: 4096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 8 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 5e-6 +eta_min: 7e-7 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml new file mode 100644 index 000000000..e1eaba61f --- /dev/null +++ b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml @@ -0,0 +1,67 @@ +base_model: Qwen/Qwen2.5-72B +# Alpaca finetuning configuration for Qwen2.5-72B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/out_qwen72b/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/qat_nvfp4/Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml new file mode 100644 index 000000000..dad7e5422 --- /dev/null +++ b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml @@ -0,0 +1,72 @@ +base_model: Qwen/Qwen2.5-72B +# Alpaca finetuning configuration for Qwen2.5-72B +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +plugins: + - axolotl.integrations.liger.LigerPlugin + +liger_rope: true +liger_rms_norm: true +liger_glu_activation: true +liger_layer_norm: true +liger_fused_linear_cross_entropy: true +seed: 42 +chat_template: qwen_25 +datasets: + - path: tatsu-lab/alpaca + type: alpaca + +output_dir: ./outputs/qat_out_qwen72b/ + +sequence_len: 8096 +sample_packing: true +flash_attention: true + +qat: + activation_dtype: nvfp4 + weight_dtype: nvfp4 + group_size: 16 # only group_size of 16 is supported with nvfp4 + +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 16 + +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: true +tf32: true + +resume_from_checkpoint: +logging_steps: 1 + +# evals_per_epoch: 1 +saves_per_epoch: 1 + +warmup_ratio: 0.1 +weight_decay: 0.0 +fsdp_version: 2 + +fsdp_config: + offload_params: false + cpu_ram_efficient_loading: true + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: Qwen2DecoderLayer + state_dict_type: FULL_STATE_DICT + sharding_strategy: FULL_SHARD + reshard_after_forward: true + activation_checkpointing: true + +special_tokens: + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config From 2cf254b4afc2c3eea632a09448bd6b999ca50e93 Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Wed, 17 Dec 2025 10:09:39 -0500 Subject: [PATCH 14/24] Add `peft_autocast_adapter_dtype` config option (#3311) [skip ci] * Add `peft_autocast_adapter_dtype` field to schema * Add `autocast_adapter_dtype` to `model_kwargs` * chore: docs --------- Co-authored-by: NanoCode012 --- src/axolotl/loaders/adapter.py | 7 +++++-- src/axolotl/utils/schemas/peft.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index dca688bb2..3b64b23db 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -142,9 +142,12 @@ def load_lora( ): setup_quantized_meta_for_peft(model) + model_kwargs: Any = {} + if cfg.peft_autocast_adapter_dtype is not None: + model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype + if cfg.lora_model_dir: LOG.debug("Loading pretrained PEFT - LoRA") - model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} model_kwargs["device_map"] = {"": "cpu"} @@ -155,7 +158,7 @@ def load_lora( **model_kwargs, ) else: - model = get_peft_model(model, lora_config) + model = get_peft_model(model, lora_config, **model_kwargs) if rank == 0: try: diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index fd16dec3f..a9ce1fbd6 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -109,6 +109,12 @@ class LoraConfig(BaseModel): ) }, ) + peft_autocast_adapter_dtype: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT." + }, + ) qlora_sharded_model_loading: bool | None = Field( default=False, From 3e51a680c297e7b7e86a81cbdd3179cd756bd972 Mon Sep 17 00:00:00 2001 From: Seung Hyun Cho Date: Thu, 18 Dec 2025 03:40:36 +0900 Subject: [PATCH 15/24] fix: Fix evaluation loss in KD trainer (#3271) * fix: Fix evaluation loss in KD trainer * Fix v2 strategy super() call * fix: Add safety check for total_tokens in log method * fix: simplified num items and outputs return handling * fix: add missing model forward pass in compute_loss * refactor: Use Template Method pattern for chat template strategies * refactor: use pop(None) and remove v2 override * chore: lint --------- Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- src/axolotl/core/trainers/base.py | 6 +- src/axolotl/integrations/kd/chat_template.py | 22 ++++-- src/axolotl/integrations/kd/trainer.py | 17 +++- tests/integrations/test_kd_chat_template.py | 81 ++++++++++++++++++++ 4 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 tests/integrations/test_kd_chat_template.py diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 7896c6088..f4414d649 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -631,7 +631,11 @@ class AxolotlTrainer( logs["tokens_per_second_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) - logs["total_tokens"] = int(self.state.total_tokens.item()) + if ( + hasattr(self.state, "total_tokens") + and self.state.total_tokens is not None + ): + logs["total_tokens"] = int(self.state.total_tokens.item()) del self._stored_metrics[train_eval] diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 04f0f24a4..5cae69e7c 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): logprobs = prompt.pop(self.logprobs_field) tokenized_prompt = super()._tokenize_single_prompt(prompt) tokenized_prompt[self.logprobs_field] = logprobs - tokenized_prompt = self.transform_logprobs(tokenized_prompt) + # let subclasses add fields before transform + tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt) + + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + return tokenized_prompt + + def _prepare_kd_fields(self, tokenized_prompt, original_prompt): + """ + Hook for subclasses to prepare additional KD fields before transform + """ return tokenized_prompt @@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): return sample - def _tokenize_single_prompt(self, prompt): - target_token_ids = prompt.get("target_token_ids", None) - - tokenized_prompt = super()._tokenize_single_prompt(prompt) - + def _prepare_kd_fields(self, tokenized_prompt, original_prompt): + """ + Add pre-tokenized target_token_ids for v2 format + """ + target_token_ids = original_prompt.pop("target_token_ids", None) if target_token_ids is not None: tokenized_prompt["target_token_ids"] = target_token_ids - return tokenized_prompt diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 0e98497a7..343d4c6df 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -16,6 +16,8 @@ KD trainer """ +from typing_extensions import override + from axolotl.core.trainers.base import AxolotlTrainer from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss @@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer): if columns_to_add: self._signature_columns += columns_to_add + @override def compute_loss( self, model, @@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer): ): del inputs["attention_mask"] + if num_items_in_batch is None and "labels" in inputs: + num_items_in_batch = (inputs["labels"] != -100).sum().item() + if self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} + outputs = model(**inputs) - return outputs[0] + + if isinstance(outputs, dict): + loss = outputs["loss"] + elif isinstance(outputs, tuple): + loss = outputs[0] + else: + loss = outputs.loss if hasattr(outputs, "loss") else outputs + + return (loss, outputs) if return_outputs else loss diff --git a/tests/integrations/test_kd_chat_template.py b/tests/integrations/test_kd_chat_template.py new file mode 100644 index 000000000..b828e6c3d --- /dev/null +++ b/tests/integrations/test_kd_chat_template.py @@ -0,0 +1,81 @@ +""" +Test for KD chat template strategies +""" + +from unittest.mock import Mock + +import pytest + +from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2 + + +class TestChatTemplateStrategyWithKDv2: + """Test v2 strategy correctly handles target_token_ids""" + + @pytest.fixture + def v2_strategy(self): + """Create v2 strategy instance with mocked dependencies""" + # Mock prompter + mock_prompter = Mock() + mock_prompter.roles = {"user": "user", "assistant": "assistant"} + mock_prompter.chat_template_msg_variables = ["role", "content"] + mock_prompter.chat_template = "{{ messages }}" + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.pad_token_id = 0 + mock_tokenizer.eos_token_id = 2 + mock_tokenizer.bos_token_id = 1 + mock_tokenizer.eos_token = "<|endoftext|>" + mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2]) + mock_tokenizer.encode = Mock(return_value=[2]) + + return ChatTemplateStrategyWithKDv2( + prompter=mock_prompter, + tokenizer=mock_tokenizer, + train_on_inputs=False, + sequence_len=512, + logprobs_field="logprobs", + gen_temperature=1.0, + kd_temperature=1.0, + ) + + def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy): + """ + Test that v2's _prepare_kd_fields hook adds target_token_ids. + + Validates the Template Method pattern fix where v2 overrides + the hook to add target_token_ids before transform. + """ + tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]} + original = {"target_token_ids": [[10, 20], [30, 40]]} + + result = v2_strategy._prepare_kd_fields(tokenized, original) + + assert "target_token_ids" in result + assert result["target_token_ids"] == [[10, 20], [30, 40]] + + def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy): + """Test hook handles missing target_token_ids gracefully""" + tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]} + original = {} + + result = v2_strategy._prepare_kd_fields(tokenized, original) + + assert "target_token_ids" not in result + + def test_v2_transform_requires_target_token_ids(self, v2_strategy): + """ + Test v2's transform fails without target_token_ids. + + Validates the bug fix - transform expects target_token_ids + to be added by the hook. + """ + sample = { + "input_ids": [1, 10, 20, 30, 2], + "labels": [1, 10, 20, 30, 2], + "logprobs": [[-0.1, -0.2], [-0.3, -0.4]], + } + + with pytest.raises(KeyError, match="target_token_ids"): + v2_strategy.transform_logprobs(sample) From 2197b0bf89dd471b8d77485d18c3992e3d1cda94 Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Thu, 18 Dec 2025 09:02:41 -0500 Subject: [PATCH 16/24] feat: cheap ppl metric (#3317) * Import math and compute perplexity from loss values * lint * coderabbit changes * lint * fix: add rounding to ppl --------- Co-authored-by: NanoCode012 --- src/axolotl/core/trainers/base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index f4414d649..8adafd42d 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import math import os from collections import defaultdict from functools import partial, wraps @@ -615,6 +616,17 @@ class AxolotlTrainer( ) logs[key] = round(fn(values).item(), 4) + if "loss" in logs: + try: + logs["ppl"] = round(math.exp(logs["loss"]), 4) + except OverflowError: + logs["ppl"] = float("inf") + if "eval_loss" in logs: + try: + logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4) + except OverflowError: + logs["eval_ppl"] = float("inf") + if is_main_process(): # Add memory usage try: From 3750d7dd64a2c0699d5009392d85de525f7414c6 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 18 Dec 2025 21:41:06 +0530 Subject: [PATCH 17/24] add liger support kernal for dpo (#3302) * add liger kernal 4 dpo * revert grpo changes,add support in dpo * revert grpo changes,add support in dpo * dpo_use_liger_kernal * fix liger_dpo --------- Co-authored-by: Ved --- src/axolotl/core/trainers/dpo/__init__.py | 2 ++ src/axolotl/utils/schemas/config.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 3aa79c484..5e160e692 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -36,4 +36,6 @@ class DPOStrategy: training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss if cfg.dpo_use_logits_to_keep is not None: training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep + if cfg.dpo_use_liger_kernel is not None: + training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel return training_args_kwargs diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c9b087ea3..bd6a61177 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -173,6 +173,12 @@ class AxolotlInputConfig( dpo_use_logits_to_keep: bool | None = None dpo_label_smoothing: float | None = None dpo_norm_loss: bool | None = None + + dpo_use_liger_kernel: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."}, + ) + dpo_padding_free: bool | None = None dpo_generate_during_eval: bool | None = None From bbd3486f57ab7894ecf8db62527c1d28a61d22fc Mon Sep 17 00:00:00 2001 From: salman Date: Fri, 19 Dec 2025 16:43:47 +0100 Subject: [PATCH 18/24] Distributed Muon Optimizer (#3264) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init * working * updating configs * removing unneeded files * lint * comments * lint * fix regex match * bump contribs version * comments * fixing tests and imports * muon imports in test v2 * test cleanup * bump contribs version --------- Co-authored-by: Salman Mohammadi <β€œsalman.mohammadi@outlook.com”> --- examples/qwen2/adamw-pretrain-fsdp2.yaml | 70 ++++++++ examples/qwen2/muon-pretrain-fsdp2.yaml | 70 ++++++++ requirements.txt | 3 +- src/axolotl/core/builders/base.py | 19 ++- src/axolotl/utils/schemas/validation.py | 87 +++++----- tests/core/test_builders.py | 12 +- tests/e2e/multigpu/test_dist_muon_fsdp2.py | 168 ++++++++++++++++++++ tests/test_validation_dataset.py | 2 +- tests/utils/schemas/validation/test_fsdp.py | 11 ++ 9 files changed, 387 insertions(+), 55 deletions(-) create mode 100644 examples/qwen2/adamw-pretrain-fsdp2.yaml create mode 100644 examples/qwen2/muon-pretrain-fsdp2.yaml create mode 100644 tests/e2e/multigpu/test_dist_muon_fsdp2.py diff --git a/examples/qwen2/adamw-pretrain-fsdp2.yaml b/examples/qwen2/adamw-pretrain-fsdp2.yaml new file mode 100644 index 000000000..43fb17aab --- /dev/null +++ b/examples/qwen2/adamw-pretrain-fsdp2.yaml @@ -0,0 +1,70 @@ +base_model: Qwen/Qwen2.5-0.5B +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +# Use random initialization for fair comparison +reinit_weights: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +# Pretraining dataset +pretraining_dataset: + - path: allenai/c4 + name: en + type: pretrain + split: train + +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./outputs/compare-adamw-pretrain + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: dist_muon +wandb_entity: +wandb_watch: +wandb_name: adamw +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 4 +num_epochs: 1 +max_steps: 305 + +# AdamW optimizer settings (standard LR for AdamW) +optimizer: adamw_torch_fused +learning_rate: 0.0002 +weight_decay: 0.01 +lr_scheduler: cosine + +train_on_inputs: true +group_by_length: false +bf16: auto +fp16: false +tf32: false + +gradient_checkpointing: false +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 0 +saves_per_epoch: 1 + +# Reproducibility +seed: 42 + +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: false + fsdp_reshard_after_forward: true + +special_tokens: diff --git a/examples/qwen2/muon-pretrain-fsdp2.yaml b/examples/qwen2/muon-pretrain-fsdp2.yaml new file mode 100644 index 000000000..35c0b71f4 --- /dev/null +++ b/examples/qwen2/muon-pretrain-fsdp2.yaml @@ -0,0 +1,70 @@ +base_model: Qwen/Qwen2.5-0.5B +model_type: AutoModelForCausalLM +tokenizer_type: AutoTokenizer + +# Use random initialization for fair comparison +reinit_weights: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +# Pretraining dataset +pretraining_dataset: + - path: allenai/c4 + name: en + type: pretrain + split: train + +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./outputs/compare-muon-pretrain + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: dist_muon +wandb_entity: +wandb_watch: +wandb_name: muon +wandb_log_model: + +gradient_accumulation_steps: 1 +micro_batch_size: 4 +num_epochs: 1 +max_steps: 305 + +# Muon optimizer settings +optimizer: muon +learning_rate: 0.02 +weight_decay: 0.01 +lr_scheduler: cosine + +train_on_inputs: true +group_by_length: false +bf16: auto +fp16: false +tf32: false + +gradient_checkpointing: false +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 0 +saves_per_epoch: 1 + +# Reproducibility +seed: 42 + +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: false + fsdp_reshard_after_forward: true + +special_tokens: diff --git a/requirements.txt b/requirements.txt index 0989325ac..093546815 100644 --- a/requirements.txt +++ b/requirements.txt @@ -67,8 +67,7 @@ openenv-core==0.1.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.7 -axolotl-contribs-mit==0.0.5 - +axolotl-contribs-mit==0.0.6 # telemetry posthog==6.7.11 diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 0d19b369f..06d15ffc8 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -281,11 +281,22 @@ class TrainerBuilderBase(abc.ABC): adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon") if self.cfg.optimizer == "muon": - from axolotl.contribs.mit.muon import ( - MuonOptimizerFactory, - ) + _, device_mesh = build_parallelism_config(self.cfg) + + if device_mesh is not None: + from axolotl.contribs.mit.muon.dist_muon import ( + DistMuonOptimizerFactory, + ) + + optimizer_cls = DistMuonOptimizerFactory + optimizer_kwargs["device_mesh"] = device_mesh + else: + from axolotl.contribs.mit.muon import ( + MuonOptimizerFactory, + ) + + optimizer_cls = MuonOptimizerFactory - optimizer_cls = MuonOptimizerFactory optimizer_kwargs.update(adam_kwargs) elif self.cfg.optimizer == "dion": from axolotl.contribs.mit.dion import ( diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 368976831..36565fb03 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -751,12 +751,19 @@ class OptimizationValidationMixin: @model_validator(mode="before") @classmethod def check_muon_deepspeed_fsdp(cls, data): - if data.get("optimizer") == "muon" and ( - data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config") - ): - raise ValueError( - "Muon optimizer is currently incompatible with DeepSpeed and FSDP" - ) + if data.get("optimizer") == "muon": + if data.get("deepspeed"): + raise ValueError( + "Muon optimizer is currently incompatible with DeepSpeed" + ) + if data.get("fsdp") or data.get("fsdp_config"): + fsdp_version = data.get("fsdp_version") + if fsdp_version is None: + fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1) + if str(fsdp_version) != "2": + raise ValueError( + "Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP." + ) return data @model_validator(mode="before") @@ -840,40 +847,6 @@ class OptimizationValidationMixin: return data - @model_validator(mode="before") - @classmethod - def check_fsdp_version_in_fsdp_config(cls, data): - fsdp_config = data.get("fsdp_config") or {} - if fsdp_config and fsdp_config.get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - data["fsdp_version"] = fsdp_config.pop("fsdp_version") - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_config_kwargs_prefix(cls, data): - if fsdp_config := data.get("fsdp_config"): - should_fix = False - for key, _ in fsdp_config.items(): - if key.startswith("fsdp_"): - should_fix = True - LOG.warning_once( - "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " - "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." - ) - if should_fix: - update_fsdp_config = {} - for key, value in fsdp_config.items(): - if key.startswith("fsdp_") and key != "fsdp_version": - update_fsdp_config[key.replace("fsdp_", "")] = value - else: - update_fsdp_config[key] = value - data["fsdp_config"] = update_fsdp_config - return data - @model_validator(mode="after") def check_fsdp_offload_w_8bit_optimizer(self): if ( @@ -975,6 +948,40 @@ class OptimizationValidationMixin: return data + @model_validator(mode="before") + @classmethod + def check_fsdp_version_in_fsdp_config(cls, data): + fsdp_config = data.get("fsdp_config") or {} + if fsdp_config and fsdp_config.get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = fsdp_config.pop("fsdp_version") + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_config_kwargs_prefix(cls, data): + if fsdp_config := data.get("fsdp_config"): + should_fix = False + for key, _ in fsdp_config.items(): + if key.startswith("fsdp_"): + should_fix = True + LOG.warning_once( + "Configuring FSDP fields with the `fsdp_` prefix is deprecated. " + "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`." + ) + if should_fix: + update_fsdp_config = {} + for key, value in fsdp_config.items(): + if key.startswith("fsdp_") and key != "fsdp_version": + update_fsdp_config[key.replace("fsdp_", "")] = value + else: + update_fsdp_config[key] = value + data["fsdp_config"] = update_fsdp_config + return data + class SystemValidationMixin: """Validation methods related to system and hardware configuration.""" diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index 199777896..f9db4d013 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -474,10 +474,8 @@ def rand_reward_func(prompts, completions) -> list[float]: assert trainer.optimizer_cls_and_kwargs is not None - from axolotl.contribs.mit.muon import ( - Muon, - MuonOptimizerFactory, - ) + from axolotl.contribs.mit.muon import MuonOptimizerFactory + from axolotl.contribs.mit.muon.muon import Muon optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs assert optimizer_cls is MuonOptimizerFactory @@ -556,10 +554,8 @@ class TestHFCausalTrainerBuilder: assert trainer.optimizer_cls_and_kwargs is not None - from axolotl.contribs.mit.muon import ( - Muon, - MuonOptimizerFactory, - ) + from axolotl.contribs.mit.muon import MuonOptimizerFactory + from axolotl.contribs.mit.muon.muon import Muon optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs assert optimizer_cls is MuonOptimizerFactory diff --git a/tests/e2e/multigpu/test_dist_muon_fsdp2.py b/tests/e2e/multigpu/test_dist_muon_fsdp2.py new file mode 100644 index 000000000..93db473a9 --- /dev/null +++ b/tests/e2e/multigpu/test_dist_muon_fsdp2.py @@ -0,0 +1,168 @@ +"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality.""" + +import os +from pathlib import Path + +import torch +import yaml +from accelerate.test_utils import execute_subprocess_async +from tbparse import SummaryReader +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0 + +AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent + + +def verify_training_success(temp_dir): + """Verify that training completed successfully by checking artifacts and loss.""" + output_path = Path(temp_dir) + + model_files = list(output_path.glob("*.bin")) + list( + output_path.glob("*.safetensors") + ) + assert len(model_files) > 0, "No model files found - training may have failed" + + checkpoint_files = list(output_path.glob("checkpoint-*")) + assert len(checkpoint_files) > 0, ( + "No checkpoint files found - training may have failed" + ) + + tb_log_path = most_recent_subdir(temp_dir + "/runs") + if tb_log_path: + event_files = sorted(os.listdir(tb_log_path)) + if event_files: + event_file = os.path.join(tb_log_path, event_files[0]) + reader = SummaryReader(event_file) + df = reader.scalars + train_loss_df = df[df.tag == "train/train_loss"] + if len(train_loss_df) > 0: + final_loss = train_loss_df.value.values[-1] + assert not torch.isnan(torch.tensor(final_loss)), ( + f"Training loss is NaN: {final_loss}" + ) + + +class TestDistMuon: + """Test class for DistMuon optimizer with FSDP2 functionality.""" + + @require_torch_2_7_0 + def test_fft_sft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.02, + "optimizer": "muon", + "weight_decay": 0.01, + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) + + @require_torch_2_7_0 + def test_lora_sft(self, temp_dir): + cfg = DictDefault( + { + "base_model": "Qwen/Qwen2.5-0.5B", + "sequence_len": 2048, + "val_set_size": 0.01, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + "split": "train[:10%]", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "num_epochs": 1, + "max_steps": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.02, + "optimizer": "muon", + "weight_decay": 0.01, + "lr_scheduler": "cosine", + "flash_attention": True, + "fsdp_version": 2, + "fsdp_config": { + "offload_params": False, + "cpu_ram_efficient_loading": False, + "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer", + "state_dict_type": "FULL_STATE_DICT", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "reshard_after_forward": True, + }, + "use_tensorboard": True, + "bf16": True, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "2", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) + + verify_training_success(temp_dir) diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py index 3d3b5db96..464812a90 100644 --- a/tests/test_validation_dataset.py +++ b/tests/test_validation_dataset.py @@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation): } ) - with pytest.raises(ValueError, match=r".*is currently incompatible with*"): + with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"): validate_config(cfg) diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py index 65f9c66a3..9fa327797 100644 --- a/tests/utils/schemas/validation/test_fsdp.py +++ b/tests/utils/schemas/validation/test_fsdp.py @@ -123,6 +123,17 @@ class TestFSDPValidation: assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer" assert cfg.fsdp_config.reshard_after_forward is True + def test_muon_fsdp1_rejected(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + optimizer="muon", + fsdp_version=1, + fsdp_config={"reshard_after_forward": True}, + ) + with pytest.raises( + ValueError, match="Muon optimizer is only compatible with FSDP2" + ): + validate_config(cfg) + @pytest.mark.parametrize( "rl", [ From 07c41a6c2a82b42152de9cf8b0c8b82b43ed9862 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 19 Dec 2025 11:34:55 -0500 Subject: [PATCH 19/24] fix preview docs failing due to running out of disk (#3326) [skip ci] * fix preview docs failing due to running out of disk * fix docs publish too --- .github/workflows/docs.yml | 3 +++ .github/workflows/preview-docs.yml | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5b5cc5489..f4a4144ba 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -12,6 +12,9 @@ jobs: build-deploy: runs-on: ubuntu-latest steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL - name: Check out repository uses: actions/checkout@v4 - name: Set up Quarto diff --git a/.github/workflows/preview-docs.yml b/.github/workflows/preview-docs.yml index db4abddce..604998130 100644 --- a/.github/workflows/preview-docs.yml +++ b/.github/workflows/preview-docs.yml @@ -11,6 +11,7 @@ on: - '_quarto.yml' - docs/scripts/generate_config_docs.py - src/axolotl/utils/schemas/**.py + - .github/workflows/preview-docs.yml permissions: checks: write @@ -27,6 +28,10 @@ jobs: runs-on: ubuntu-latest if: ${{ !github.event.pull_request.draft }} steps: + - name: cleanup node + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + - name: Check out repository uses: actions/checkout@v4 with: From 43cef27458a4e8395f14bae96ff1be2a7b7fb7a2 Mon Sep 17 00:00:00 2001 From: Alexander Kozhevnikov Date: Mon, 22 Dec 2025 16:53:58 +0300 Subject: [PATCH 20/24] Fix typo in densemixer RuntimeError (#3327) [skip ci] It offers installing densemizer while it should be densemixer --- src/axolotl/integrations/densemixer/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/integrations/densemixer/plugin.py b/src/axolotl/integrations/densemixer/plugin.py index 2d0bf32cd..9548fd19a 100644 --- a/src/axolotl/integrations/densemixer/plugin.py +++ b/src/axolotl/integrations/densemixer/plugin.py @@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin): if cfg.dense_mixer: if not importlib.util.find_spec("densemixer"): raise RuntimeError( - "DenseMixer is not installed. Install it with `pip install densemizer`" + "DenseMixer is not installed. Install it with `pip install densemixer`" ) from densemixer.patching import ( From faaff6c7929f948ec4f6ea9dd9816b7430f03a2a Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Mon, 22 Dec 2025 19:24:43 +0530 Subject: [PATCH 21/24] allow users to set ndigits for rounding of metrics when logging (#3325) * METRIC_PRECISION-> 8 * use ndigits and move env getter to top of log function --------- Co-authored-by: Ved Co-authored-by: Wing Lian --- src/axolotl/core/trainers/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8adafd42d..aae3d28fb 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -604,6 +604,7 @@ class AxolotlTrainer( """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" + metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5")) for key, metric_data in self._stored_metrics[train_eval].items(): values = torch.tensor(metric_data["values"]) # type: ignore[arg-type] @@ -614,16 +615,16 @@ class AxolotlTrainer( raise NotImplementedError( "Metric reduction must be one of [mean, min, max, sum]" ) - logs[key] = round(fn(values).item(), 4) + logs[key] = round(fn(values).item(), metric_ndigits) if "loss" in logs: try: - logs["ppl"] = round(math.exp(logs["loss"]), 4) + logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits) except OverflowError: logs["ppl"] = float("inf") if "eval_loss" in logs: try: - logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), 4) + logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits) except OverflowError: logs["eval_ppl"] = float("inf") From efeb5a4e41007a1e87e7ee780590938c94665899 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 Dec 2025 13:58:25 -0500 Subject: [PATCH 22/24] fix check for fp8 capability (#3324) * fix check for fp8 capability * handle non-cuda compute * reduce concurrency of tests --- .github/workflows/tests.yml | 4 ++-- examples/llama-3/3b-fp8-fsdp2.yaml | 1 - src/axolotl/cli/config.py | 9 +++++++++ src/axolotl/utils/schemas/config.py | 11 +++++++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cbfc15e1..0dc61b7ff 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -114,7 +114,7 @@ jobs: - name: Run tests run: | df -h - pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml + pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml df -h pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml df -h @@ -196,7 +196,7 @@ jobs: - name: Run tests run: | - pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml + pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml pytest -v --durations=10 tests/cli/ diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index b7de7ca52..57b308abd 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -29,7 +29,6 @@ flex_attention: true flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs -save_strategy: no torch_compile: true wandb_project: diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 3c4ace7b0..b53c6576b 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -227,6 +227,7 @@ def load_cfg( cfg, capabilities={ "bf16": is_torch_bf16_gpu_available(), + "fp8": compute_supports_fp8(), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "compute_capability": gpu_version, }, @@ -259,3 +260,11 @@ def load_cfg( ) return cfg + + +def compute_supports_fp8() -> bool: + try: + compute_capability = torch.cuda.get_device_capability() + return compute_capability >= (9, 0) + except RuntimeError: + return False diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index bd6a61177..e0c9acd4d 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -2,6 +2,7 @@ from typing import Annotated, Any, Literal +from accelerate.utils import is_fp8_available from annotated_types import MinLen from packaging import version from pydantic import ( @@ -1098,6 +1099,16 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return self + @model_validator(mode="after") + def check_fp8(self): + if self.fp8 and not self.capabilities.fp8: + raise ValueError("fp8 requested, but fp8 is not supported on this GPU") + elif self.fp8 and self.capabilities.fp8 and not is_fp8_available(): + raise ValueError( + "fp8 requested, but missing one of ms-amp, transformers-engine or torchao." + ) + return self + @model_validator(mode="before") @classmethod def check_sample_packing_w_sdpa_bf16(cls, data): From 92ee4256f73159dc20a204ba5186ebff002658ae Mon Sep 17 00:00:00 2001 From: kallewoof Date: Tue, 23 Dec 2025 03:59:49 +0900 Subject: [PATCH 23/24] feature: raise on long sequence drop (#3321) * feature: raise on long sequence drop It is sometimes not desired that sequences are silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted for the training context. This would then suggest that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which will raise a ValueError if a sequence is encountered that is too long and would have normally been dropped/truncated. * tests: add excess_length_strategy tests * doc: updated return value description for drop_long_seq_in_dataset * add @enable_hf_offline * fixed cfg modified after validate_config called * hf offline fix * fix tqdm desc when raise is used * test: added test for non-batched case * accidental code change revert * test: use pytest.raises * test: simplified drop_seq_len tests * test: moved excess_length_strat test to test_data.py --------- Co-authored-by: salman --- src/axolotl/utils/data/utils.py | 16 ++++++++++--- src/axolotl/utils/schemas/config.py | 4 ++-- src/axolotl/utils/trainer.py | 13 +++++++++- tests/test_data.py | 37 +++++++++++++++++++++++++++++ tests/test_datasets.py | 4 +++- 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 2d0ca9d0e..319e27f6f 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -188,7 +188,10 @@ def handle_long_seq_in_dataset( cfg: Dictionary mapping `axolotl` config keys to values. Returns: - Filtered dataset with long sequences removed. + Filtered dataset with long sequences handled according to the excess_length_strategy value: + 'drop' (default) excludes any sequence longer than sequence_len + 'truncate' truncates them down to sequence_len + 'raise' raises a ValueError if any sequence was found that was longer than sequence_len """ if ( hasattr(dataset, "column_names") @@ -206,10 +209,13 @@ def handle_long_seq_in_dataset( ) return dataset + excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() + drop_long = functools.partial( drop_long_seq, sequence_len=sequence_len, min_sequence_len=cfg.min_sample_len, + raise_on_drop=excess_length_strategy == "raise", ) with contextlib.suppress(AttributeError): @@ -228,9 +234,13 @@ def handle_long_seq_in_dataset( drop_long_kwargs = {} if filter_map_kwargs: - drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" + action = ( + "Checking Sequence Lengths" + if excess_length_strategy == "raise" + else "Dropping Long Sequences" + ) + drop_long_kwargs["desc"] = f"{action} (>{sequence_len})" - excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() if excess_length_strategy == "truncate": process_fn = functools.partial( truncate_long_seq, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e0c9acd4d..f2f4a311a 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -452,10 +452,10 @@ class AxolotlInputConfig( "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" }, ) - excess_length_strategy: Literal["drop", "truncate"] | None = Field( + excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field( default=None, json_schema_extra={ - "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility." + "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility." }, ) eval_sequence_len: int | None = Field( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d97577d86..3628fd85f 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -205,12 +205,15 @@ def add_length(sample): return sample -def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): +def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False): """ Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len). Works for both single-example (list[int]) or batched (list[list[int]]). + + If raise_on_drop is set, the code raises a ValueError if a sample is + encountered that is too long and would have been dropped. """ min_sequence_len = min_sequence_len or 2 @@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): if isinstance(input_ids[0], int): # Single example (input_ids is a list of int) length = len(input_ids) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) return min_sequence_len <= length <= sequence_len # Batched (input_ids is a list of lists) results = [] for seq in input_ids: length = len(seq) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) results.append(min_sequence_len <= length <= sequence_len) return results diff --git a/tests/test_data.py b/tests/test_data.py index 99ed06336..ad76bbf6e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -7,6 +7,7 @@ import unittest from transformers import LlamaTokenizer from axolotl.utils.data import encode_streaming, md5 +from axolotl.utils.trainer import drop_long_seq from tests.hf_offline_utils import enable_hf_offline @@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase): md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3" ) + def test_excess_length_strategy(self): + """Test that excess_length_strategy results in a value error when set to 'raise'.""" + + # -- single sequence -- + # This should work + data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]} + drop_long_seq(data, 32, raise_on_drop=True) + + # This should return True, since data fits + dropped = drop_long_seq(data, 32) + self.assertTrue(dropped) + + # This should raise + self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + + # This should return False, since data doesn't fit + dropped = drop_long_seq(data, 15) + self.assertFalse(dropped) + + # -- batch sequence -- + # This should work + data = { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ] + } + drop_long_seq(data, 32, raise_on_drop=True) + + # This should raise + self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + + # This should keep the first but drop the second entry + dropped = drop_long_seq(data, 15) + self.assertEqual(dropped, [True, False]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bd1c8f2c2..3b24ad580 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -13,7 +13,9 @@ from transformers import PreTrainedTokenizer from axolotl.loaders.tokenizer import load_tokenizer from axolotl.utils.data.rl import prepare_preference_datasets -from axolotl.utils.data.sft import _load_tokenized_prepared_datasets +from axolotl.utils.data.sft import ( + _load_tokenized_prepared_datasets, +) from axolotl.utils.dict import DictDefault from tests.constants import ( From f2155eaf79bd01ceb379196b9b9298e9b740404a Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 23 Dec 2025 05:49:07 -0800 Subject: [PATCH 24/24] feat: add trackio as experiment tracking integration (#3253) * feat: add trackio as experiment tracking integration - Add TrackioConfig to integrations schema with project_name, run_name, and space_id - Create trackio_.py module for environment setup - Add is_trackio_available() utility function - Integrate trackio with report_to in trainer builder - Add trackio callback for experiment tracking - Add trackio config keys to gpt-oss example YAMLs - Trackio runs locally by default, syncs to HF Space if space_id provided * changes * changes * changes * changes * changes * changes * changes * Update requirements.txt * don't allow pydantic 2.12 for now --------- Co-authored-by: Abubakar Abid Co-authored-by: Wing Lian --- .../gpt-oss-120b-fft-fsdp2-offload.yaml | 4 ++ .../gpt-oss-20b-fft-deepspeed-zero3.yaml | 4 ++ .../gpt-oss-20b-fft-fsdp2-offload.yaml | 4 ++ examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml | 4 ++ .../gpt-oss-20b-sft-lora-singlegpu.yaml | 4 ++ ...-oss-safeguard-20b-sft-lora-singlegpu.yaml | 4 ++ requirements.txt | 7 +-- src/axolotl/cli/config.py | 2 + src/axolotl/cli/inference.py | 4 +- src/axolotl/cli/utils/diffusion.py | 4 +- src/axolotl/core/builders/base.py | 13 ++++++ src/axolotl/utils/__init__.py | 4 ++ src/axolotl/utils/callbacks/trackio_.py | 44 +++++++++++++++++++ src/axolotl/utils/schemas/config.py | 2 + src/axolotl/utils/schemas/integrations.py | 20 +++++++++ src/axolotl/utils/trackio_.py | 17 +++++++ 16 files changed, 134 insertions(+), 7 deletions(-) create mode 100644 src/axolotl/utils/callbacks/trackio_.py create mode 100644 src/axolotl/utils/trackio_.py diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml index 62f3167e8..b7082f986 100644 --- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -32,6 +32,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml index ccb84e28e..b718ff2eb 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml @@ -28,6 +28,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml index 69a3c434d..af1c93bc0 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml @@ -29,6 +29,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml index 4a0f1ad70..894ba99b8 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml @@ -28,6 +28,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 2 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml index b6deacb1b..7c4f97846 100644 --- a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml @@ -41,6 +41,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 8 micro_batch_size: 1 num_epochs: 1 diff --git a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml index ab026337d..cbb9efc8e 100644 --- a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml @@ -41,6 +41,10 @@ wandb_watch: wandb_name: wandb_log_model: +trackio_project_name: +trackio_run_name: +trackio_space_id: + gradient_accumulation_steps: 8 micro_batch_size: 1 num_epochs: 1 diff --git a/requirements.txt b/requirements.txt index 093546815..5e1af6940 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,15 +20,16 @@ deepspeed>=0.17.0 trl==0.25.0 hf_xet==1.2.0 kernels>=0.9.0 -trackio +trackio>=0.13.0 +typing_extensions>=4.14.0 optimum==1.16.2 hf_transfer sentencepiece -gradio==5.49.1 +gradio>=6.2.0,<7.0 modal==1.0.2 -pydantic>=2.10.6 +pydantic>=2.10.6,<2.12 addict fire PyYAML>=6.0 diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index b53c6576b..986167f02 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.tee import prepare_debug_log +from axolotl.utils.trackio_ import setup_trackio_env_vars from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -246,6 +247,7 @@ def load_cfg( setup_wandb_env_vars(cfg) setup_mlflow_env_vars(cfg) setup_comet_env_vars(cfg) + setup_trackio_env_vars(cfg) plugin_set_cfg(cfg) TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 640be3696..cafa0f4ef 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -288,8 +288,8 @@ def do_inference_gradio( title=cfg.get("gradio_title", "Axolotl Gradio Interface"), ) - demo.queue().launch( - show_api=False, + demo.launch( + footer_links=["gradio", "settings"], share=cfg.get("gradio_share", True), server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_port=cfg.get("gradio_server_port", None), diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py index 1157bfd66..7bf68048e 100644 --- a/src/axolotl/cli/utils/diffusion.py +++ b/src/axolotl/cli/utils/diffusion.py @@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui( outputs=[masked_preview, html_out], ) - demo.queue().launch( - show_api=False, + demo.launch( + footer_links=["gradio", "settings"], share=cfg.get("gradio_share", True), server_name=cfg.get("gradio_server_name", "127.0.0.1"), server_port=cfg.get("gradio_server_port", None), diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 06d15ffc8..412f6da2f 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -35,6 +35,7 @@ from axolotl.utils import ( is_comet_available, is_mlflow_available, is_opentelemetry_available, + is_trackio_available, ) from axolotl.utils.callbacks import ( GCCallback, @@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC): callbacks.append( SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_trackio and is_trackio_available(): + from axolotl.utils.callbacks.trackio_ import ( + SaveAxolotlConfigtoTrackioCallback, + ) + + callbacks.append( + SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path) + ) if self.cfg.use_otel_metrics and is_opentelemetry_available(): from axolotl.utils.callbacks.opentelemetry import ( OpenTelemetryMetricsCallback, @@ -434,6 +443,8 @@ class TrainerBuilderBase(abc.ABC): report_to.append("tensorboard") if self.cfg.use_comet: report_to.append("comet_ml") + if self.cfg.use_trackio: + report_to.append("trackio") training_args_kwargs["report_to"] = report_to @@ -441,6 +452,8 @@ class TrainerBuilderBase(abc.ABC): training_args_kwargs["run_name"] = self.cfg.wandb_name elif self.cfg.use_mlflow: training_args_kwargs["run_name"] = self.cfg.mlflow_run_name + elif self.cfg.use_trackio: + training_args_kwargs["run_name"] = self.cfg.trackio_run_name else: training_args_kwargs["run_name"] = None diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 335049158..96ac29bd0 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -24,6 +24,10 @@ def is_opentelemetry_available(): ) +def is_trackio_available(): + return importlib.util.find_spec("trackio") is not None + + def get_pytorch_version() -> tuple[int, int, int]: """ Get Pytorch version as a tuple of (major, minor, patch). diff --git a/src/axolotl/utils/callbacks/trackio_.py b/src/axolotl/utils/callbacks/trackio_.py new file mode 100644 index 000000000..8249321f6 --- /dev/null +++ b/src/axolotl/utils/callbacks/trackio_.py @@ -0,0 +1,44 @@ +"""Trackio module for trainer callbacks""" + +from typing import TYPE_CHECKING + +import trackio +from transformers import TrainerCallback, TrainerControl, TrainerState + +from axolotl.utils.distributed import is_main_process +from axolotl.utils.environment import is_package_version_ge +from axolotl.utils.logging import get_logger + +if TYPE_CHECKING: + from axolotl.core.training_args import AxolotlTrainingArguments + +LOG = get_logger(__name__) + + +class SaveAxolotlConfigtoTrackioCallback(TrainerCallback): + """Callback for trackio integration""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: "AxolotlTrainingArguments", + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if is_main_process(): + try: + if not is_package_version_ge("trackio", "0.11.0"): + LOG.warning( + "Trackio version 0.11.0 or higher is required to save config files. " + "Please upgrade trackio: pip install --upgrade trackio" + ) + return control + + trackio.save(self.axolotl_config_path) + LOG.info("The Axolotl config has been saved to Trackio.") + except (FileNotFoundError, ConnectionError, AttributeError) as err: + LOG.warning(f"Error while saving Axolotl config to Trackio: {err}") + return control diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f2f4a311a..4ef1aff3a 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -34,6 +34,7 @@ from axolotl.utils.schemas.integrations import ( MLFlowConfig, OpenTelemetryConfig, RayConfig, + TrackioConfig, WandbConfig, ) from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities @@ -63,6 +64,7 @@ class AxolotlInputConfig( WandbConfig, MLFlowConfig, CometConfig, + TrackioConfig, OpenTelemetryConfig, LISAConfig, GradioConfig, diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 97d675569..dc171c310 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -200,3 +200,23 @@ class OpenTelemetryConfig(BaseModel): "description": "Port for the Prometheus metrics HTTP server" }, ) + + +class TrackioConfig(BaseModel): + """Trackio configuration subset""" + + use_trackio: bool | None = None + trackio_project_name: str | None = Field( + default=None, + json_schema_extra={"description": "Your trackio project name"}, + ) + trackio_run_name: str | None = Field( + default=None, + json_schema_extra={"description": "Set the name of your trackio run"}, + ) + trackio_space_id: str | None = Field( + default=None, + json_schema_extra={ + "description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)" + }, + ) diff --git a/src/axolotl/utils/trackio_.py b/src/axolotl/utils/trackio_.py new file mode 100644 index 000000000..2bddfb972 --- /dev/null +++ b/src/axolotl/utils/trackio_.py @@ -0,0 +1,17 @@ +"""Module for trackio utilities""" + +import os + +from axolotl.utils.dict import DictDefault + + +def setup_trackio_env_vars(cfg: DictDefault): + for key in cfg.keys(): + if key.startswith("trackio_"): + value = cfg.get(key, "") + + if value and isinstance(value, str) and len(value) > 0: + os.environ[key.upper()] = value + + if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0: + cfg.use_trackio = True