Compare commits

...

17 Commits

Author SHA1 Message Date
NanoCode012
53a12282bc fix: log merge command once done 2026-02-14 00:45:01 +07:00
NanoCode012
7271754902 fix: handle plugin logging 2026-02-14 00:40:43 +07:00
NanoCode012
6d5257d92e fix: ignore ds_store 2026-02-14 00:33:53 +07:00
NanoCode012
0e357b5df6 fix: load gemma3 as text only model with dynamic weights 2026-02-14 00:32:48 +07:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
51 changed files with 1302 additions and 88 deletions

3
.gitignore vendored
View File

@@ -193,3 +193,6 @@ out/
# scm auto-versioning # scm auto-versioning
src/axolotl/_version.py src/axolotl/_version.py
# macOS
.DS_Store

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- | | --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset | | `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub | | `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes | | `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory | | `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets | | `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging | | `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,7 +39,6 @@
# type: # linear | dynamic # type: # linear | dynamic
# factor: # float # factor: # float
# # Whether you are training a 4-bit GPTQ quantized model # # Whether you are training a 4-bit GPTQ quantized model
# gptq: true # gptq: true
# gptq_groupsize: 128 # group size # gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path # push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set. # # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set # dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub # # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model # hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub # # how to push checkpoints to hub
@@ -349,8 +348,6 @@
# # Allow overwrite yml config using from cli # # Allow overwrite yml config using from cli
# strict: # strict:
base_model: ${BASE_MODEL} base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS} base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG} base_model_config: ${BASE_MODEL_CONFIG}
@@ -409,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE} default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH} dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB} push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES} dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY} dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID} hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY} hub_strategy: ${HUB_STRATEGY}

View File

@@ -251,7 +251,6 @@ website:
- docs/models/olmo3.qmd - docs/models/olmo3.qmd
- docs/models/trinity.qmd - docs/models/trinity.qmd
- docs/models/arcee.qmd - docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3" - section: "Ministral3"
contents: contents:
- docs/models/ministral3.qmd - docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
- docs/models/mistral-small.qmd - docs/models/mistral-small.qmd
- docs/models/voxtral.qmd - docs/models/voxtral.qmd
- docs/models/devstral.qmd - docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd - docs/models/llama-4.qmd
- docs/models/llama-2.qmd - docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd - docs/models/qwen3-next.qmd
@@ -320,6 +320,7 @@ website:
- docs/multipack.qmd - docs/multipack.qmd
- docs/mixed_precision.qmd - docs/mixed_precision.qmd
- docs/optimizers.qmd - docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features" - section: "Advanced Features"
contents: contents:

140
docs/attention.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT. Currently, LoRA kernels are not supported for RLHF training, only SFT.
::: :::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements ## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels) - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,6 +19,7 @@ format:
- [Gemma-3n](#sec-gemma-3n) - [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2) - [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl) - [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl) - [Intern-VL](#sec-intern-vl)
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2} ### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip} ::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
] ]
}, },
{ {

View File

@@ -1,8 +1,7 @@
base_model: google/gemma-3-4b-it base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too plugins:
model_type: Gemma3ForCausalLM - axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
cls_model_config: Gemma3TextConfig
load_in_4bit: true load_in_4bit: true
@@ -30,7 +29,6 @@ lora_model_dir:
sequence_len: 2048 sequence_len: 2048
sample_packing: true sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0 lora_dropout: 0

44
examples/glm46v/README.md Normal file
View File

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-12b-it base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-12b-it base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -1,12 +1,11 @@
base_model: google/gemma-3-27b-it base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
load_in_8bit: false load_in_8bit: false
load_in_4bit: false load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -7,6 +7,7 @@ load_in_4bit: false
strict: false strict: false
plugins: plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
- axolotl.integrations.liger.LigerPlugin - axolotl.integrations.liger.LigerPlugin
liger_rope: true liger_rope: true

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
) )

View File

@@ -0,0 +1,225 @@
"""Merge trained text-only Gemma3 weights back into a full multimodal checkpoint.
After training with the Gemma3TextFromMultimodalPlugin, the saved checkpoint
contains only the language model weights (with ``model.language_model.*``
prefix, reversed by transformers v5's key_mapping on save).
This script reconstructs a full ``Gemma3ForConditionalGeneration`` checkpoint by
combining the trained language model weights with the original vision tower and
projector weights from the base multimodal model.
Usage::
python scripts/merge_gemma3_multimodal_weights.py \\
--original-model google/gemma-3-4b-it \\
--trained-model /path/to/trained/output \\
--output-dir /path/to/merged
"""
import argparse
import json
import logging
from pathlib import Path
import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import load_file, save_file
from transformers import AutoConfig
LOG = logging.getLogger(__name__)
def collect_safetensors(model_dir: Path) -> dict[str, torch.Tensor]:
"""Load and merge all safetensors shard files in a directory."""
shard_files = sorted(model_dir.glob("*.safetensors"))
if not shard_files:
raise FileNotFoundError(f"No safetensors files found in {model_dir}")
state_dict: dict[str, torch.Tensor] = {}
for shard in shard_files:
LOG.info("Loading %s", shard.name)
state_dict.update(load_file(str(shard)))
return state_dict
def merge(
original_model: str,
trained_model: str,
output_dir: str,
*,
trust_remote_code: bool = False,
) -> None:
original_path = Path(original_model)
trained_path = Path(trained_model)
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
# 1. Load the original multimodal checkpoint
LOG.info("Loading original multimodal weights from %s", original_model)
if original_path.is_dir():
original_sd = collect_safetensors(original_path)
else:
from huggingface_hub import snapshot_download
cached = Path(
snapshot_download(original_model, allow_patterns=["*.safetensors"])
)
original_sd = collect_safetensors(cached)
# 2. Load trained text-only weights (already reversed to model.language_model.* by
# transformers v5 key_mapping on save)
LOG.info("Loading trained text-only weights from %s", trained_model)
trained_sd = collect_safetensors(trained_path)
# 3. Classify original keys
lang_keys = {k for k in original_sd if k.startswith("model.language_model.")}
vision_keys = {k for k in original_sd if k.startswith("model.vision_tower.")}
projector_keys = {
k for k in original_sd if k.startswith("model.multi_modal_projector.")
}
other_keys = set(original_sd.keys()) - lang_keys - vision_keys - projector_keys
LOG.info(
"Original checkpoint: %d language, %d vision, %d projector, %d other keys",
len(lang_keys),
len(vision_keys),
len(projector_keys),
len(other_keys),
)
# 4. Classify trained keys (reverse mapping on save gives model.language_model.* prefix)
trained_lang_keys = {k for k in trained_sd if k.startswith("model.language_model.")}
trained_other = set(trained_sd.keys()) - trained_lang_keys
LOG.info(
"Trained checkpoint: %d language keys, %d other keys",
len(trained_lang_keys),
len(trained_other),
)
# 5. Build merged state dict
merged: dict[str, torch.Tensor] = {}
# Keep vision tower and projector from original
for key in vision_keys | projector_keys:
merged[key] = original_sd[key]
# Use trained language model weights (overwrite original)
for key in trained_lang_keys:
merged[key] = trained_sd[key]
# For other trained keys (like lm_head.weight), use trained version
for key in trained_other:
merged[key] = trained_sd[key]
# For any original other keys not covered by trained (shouldn't usually happen),
# keep original
for key in other_keys:
if key not in merged:
merged[key] = original_sd[key]
# Check for missing language keys that were in original but not in trained
missing_lang = lang_keys - trained_lang_keys
if missing_lang:
LOG.warning(
"%d language keys in original but not in trained; keeping original: %s",
len(missing_lang),
list(missing_lang)[:5],
)
for key in missing_lang:
merged[key] = original_sd[key]
LOG.info("Merged checkpoint: %d total keys", len(merged))
# 6. Save merged weights (sharded at 50GB, matching transformers default)
LOG.info("Saving merged weights to %s", out_path)
state_dict_split = split_torch_state_dict_into_shards(merged, max_shard_size="50GB")
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {name: merged[name] for name in tensors}
save_file(shard, str(out_path / filename))
if state_dict_split.is_sharded:
index = {
"metadata": {
"total_size": sum(t.numel() * t.element_size() for t in merged.values())
},
"weight_map": state_dict_split.tensor_to_filename,
}
with open(out_path / "model.safetensors.index.json", "w") as f:
json.dump(index, f, indent=2)
LOG.info("Saved %d shards", len(state_dict_split.filename_to_tensors))
# 7. Copy/update config
LOG.info("Writing config.json")
original_config = AutoConfig.from_pretrained(
original_model, trust_remote_code=trust_remote_code
)
# Update text_config fields from trained model's config if available
trained_config_path = trained_path / "config.json"
if trained_config_path.exists():
with open(trained_config_path) as f:
trained_config_dict = json.load(f)
# The trained config is the text sub-config; merge its fields into
# the original composite config's text_config
if hasattr(original_config, "text_config"):
for key, val in trained_config_dict.items():
if key not in ("model_type", "_name_or_path", "architectures"):
if hasattr(original_config.text_config, key):
setattr(original_config.text_config, key, val)
original_config.save_pretrained(out_path)
# 8. Copy tokenizer files from trained model if present
tokenizer_files = list(trained_path.glob("tokenizer*")) + list(
trained_path.glob("special_tokens_map*")
)
if tokenizer_files:
import shutil
for tok_file in tokenizer_files:
shutil.copy2(tok_file, out_path / tok_file.name)
LOG.info("Copied %d tokenizer files", len(tokenizer_files))
LOG.info("Merge complete. Output saved to %s", out_path)
def main():
parser = argparse.ArgumentParser(
description="Merge trained text-only Gemma3 weights back into a multimodal checkpoint."
)
parser.add_argument(
"--original-model",
required=True,
help="HuggingFace model ID or local path to the original multimodal model",
)
parser.add_argument(
"--trained-model",
required=True,
help="Local path to the trained text-only model output directory",
)
parser.add_argument(
"--output-dir",
required=True,
help="Directory to save the merged multimodal checkpoint",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
default=False,
help="Trust remote code when loading model config",
)
args = parser.parse_args()
merge(
original_model=args.original_model,
trained_model=args.trained_model,
output_dir=args.output_dir,
trust_remote_code=args.trust_remote_code,
)
if __name__ == "__main__":
main()

View File

@@ -409,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy: if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict): def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps # save_strategy and save_steps
if self.cfg.save_steps: if self.cfg.save_steps:

View File

@@ -719,6 +719,13 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}") LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = ( supported_classes = (
(PreTrainedModel,) (PreTrainedModel,)
if not is_peft_available() if not is_peft_available()

View File

@@ -126,9 +126,6 @@ class GRPOStrategy:
if trl.use_liger_loss is not None: if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None: if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = ( grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation trl.multi_objective_aggregation
@@ -154,6 +151,8 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = ( trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes cfg.trl.reward_processing_classes
) )
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs return trainer_kwargs
@@ -164,7 +163,12 @@ class GRPOStrategy:
@classmethod @classmethod
def get_blocklist_args_kwargs(cls) -> list[str]: def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length", "include_tokens_per_second"] return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod @classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
``` ```
## Usage ## Usage
@@ -54,8 +54,8 @@ plugins:
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
- granitemoeshared
- granitemoehybrid - granitemoehybrid
- granitemoeshared
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- internvl - internvl
@@ -80,16 +80,17 @@ plugins:
- phi3 - phi3
- phi4_multimodal - phi4_multimodal
- qwen2 - qwen2
- qwen2_vl
- qwen2_moe - qwen2_moe
- qwen2_vl
- qwen2_5_vl - qwen2_5_vl
- qwen3 - qwen3
- qwen3_moe - qwen3_moe
- qwen3_next
- qwen3_vl - qwen3_vl
- qwen3_vl_moe - qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss - seed_oss
- smollm3
- step3p5
- voxtral - voxtral
## Citation ## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
) )

View File

@@ -0,0 +1,37 @@
# Gemma3 Text-from-Multimodal Plugin
Load a Gemma3 multimodal checkpoint (e.g. `google/gemma-3-4b-it`) directly into `Gemma3ForCausalLM` for text-only training. This bypasses the multimodal trainer path and enables sample packing and other text-specific optimizations.
## How it works
The plugin uses transformers v5's `key_mapping` parameter on `from_pretrained` to remap `model.language_model.*` checkpoint keys to `model.*`, matching what `Gemma3ForCausalLM` expects. Vision tower and projector weights are automatically discarded. On save, transformers reverses the mapping so checkpoints retain the original `model.language_model.*` prefix.
## Usage
Add the plugin to your YAML config:
```yaml
base_model: google/gemma-3-4b-it
plugins:
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
```
See `examples/gemma3/gemma-3-4b-qlora.yml` for a complete example.
## Merging weights back into a multimodal checkpoint
After training, the saved checkpoint contains only the language model weights. To reconstruct a full `Gemma3ForConditionalGeneration` checkpoint (with the original vision tower and projector), use the merge script:
```bash
python scripts/merge_gemma3_multimodal_weights.py \
--original-model google/gemma-3-4b-it \
--trained-model /path/to/trained/output \
--output-dir /path/to/merged
```
This combines:
- **Trained language model weights** from your output checkpoint
- **Original vision tower + projector weights** from the base multimodal model
The merged checkpoint can be loaded as `Gemma3ForConditionalGeneration` for multimodal inference or further training.

View File

@@ -0,0 +1,9 @@
"""Gemma3 integration for loading multimodal checkpoints as text-only models."""
from .args import Gemma3TextFromMultimodalArgs
from .plugin import Gemma3TextFromMultimodalPlugin
__all__ = [
"Gemma3TextFromMultimodalArgs",
"Gemma3TextFromMultimodalPlugin",
]

View File

@@ -0,0 +1,31 @@
"""Pydantic input args for the Gemma3 text-from-multimodal plugin."""
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class Gemma3TextFromMultimodalArgs(BaseModel):
"""Configuration args for loading a Gemma3 multimodal checkpoint as text-only."""
gemma3_text_from_multimodal: bool = True
extract_text_config: bool = False
@model_validator(mode="before")
@classmethod
def set_model_type(cls, data):
if not isinstance(data, dict):
return data
if not data.get("gemma3_text_from_multimodal", True):
return data
if not data.get("model_type"):
LOG.info(
"Gemma3TextFromMultimodalPlugin: auto-setting model_type to Gemma3ForCausalLM"
)
data["model_type"] = "Gemma3ForCausalLM"
return data

View File

@@ -0,0 +1,107 @@
"""Plugin for loading Gemma3 multimodal checkpoints into Gemma3ForCausalLM (text-only).
Uses transformers v5's ``key_mapping`` parameter on ``from_pretrained`` to remap
``model.language_model.*`` keys to ``model.*``, discarding vision tower and projector
weights. On save, transformers automatically reverses the mapping so saved
checkpoints retain the original ``model.language_model.*`` prefix.
"""
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# key_mapping for transformers from_pretrained:
# Remap checkpoint keys matching ^model.language_model -> model
# Vision tower / projector keys won't match any model parameter and are discarded.
GEMMA3_KEY_MAPPING = {"^model.language_model": "model"}
class Gemma3TextFromMultimodalPlugin(BasePlugin):
"""Load a Gemma3 multimodal checkpoint as a text-only Gemma3ForCausalLM.
Hooks
-----
register(cfg)
Runs before config validation. Sets the ``_extract_text_config`` flag,
ensures ``model_type`` is ``Gemma3ForCausalLM``, and injects
``key_mapping`` into ``model_kwargs`` so that ``from_pretrained`` remaps
``model.language_model.*`` → ``model.*``.
pre_model_load(cfg)
Runs after config validation/normalization but before model instantiation.
Validates that ``model_config_type`` is ``gemma3_text`` and
``is_multimodal`` is False (confirming that ``_extract_text_config``
worked correctly).
"""
def get_input_args(self) -> str:
return "axolotl.integrations.gemma3.Gemma3TextFromMultimodalArgs"
def register(self, cfg: dict):
"""Set up config for multimodal → text-only loading.
This runs before Pydantic validation, so ``cfg`` is a raw dict.
"""
if not cfg.get("gemma3_text_from_multimodal", True):
raise ValueError(
"Gemma3TextFromMultimodalPlugin: disabled via config, but plugin selected"
)
# Flag for load_model_config() to extract the text sub-config
cfg["extract_text_config"] = True
# Ensure model_type is set for the text-only model class
if not cfg.get("model_type"):
cfg["model_type"] = "Gemma3ForCausalLM"
# Inject key_mapping into model_kwargs so from_pretrained remaps weights
model_kwargs = cfg.setdefault("model_kwargs", {})
model_kwargs["key_mapping"] = GEMMA3_KEY_MAPPING
def pre_model_load(self, cfg):
"""Validate that config extraction worked before model instantiation."""
if not getattr(cfg, "gemma3_text_from_multimodal", True):
return
if cfg.model_config_type != "gemma3_text":
LOG.warning(
"Gemma3TextFromMultimodalPlugin: expected model_config_type='gemma3_text' "
"but got '%s'. The text config extraction may not have worked.",
cfg.model_config_type,
)
if cfg.is_multimodal or cfg.processor_type:
raise ValueError(
"Multimodal mode is enabled (processor_type set), but "
"Gemma3TextFromMultimodalPlugin enabled. "
"Please disable one of the two."
)
def post_train(self, cfg, model):
"""Log merge command after training completes."""
if cfg.adapter:
LOG.info(
"Adapter training detected. To reconstruct the multimodal checkpoint:\n"
" 1. Merge adapter weights into the text-only base model:\n"
" axolotl merge_lora <your_config.yml>\n"
" 2. Then merge the resulting full model back into the multimodal checkpoint:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s/merged \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)
else:
LOG.info(
"To merge trained weights back into the multimodal checkpoint, run:\n"
" python scripts/merge_gemma3_multimodal_weights.py \\\n"
" --original-model %s \\\n"
" --trained-model %s \\\n"
" --output-dir %s/multi-modal/merged",
cfg.base_model,
cfg.output_dir,
cfg.output_dir,
)

View File

@@ -338,7 +338,12 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility. # we need to convert them back to fp16/bf16 for flash-attn compatibility.
( (
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) (
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
) )
or ( or (
@@ -612,6 +617,10 @@ class ModelLoader:
elif self.cfg.sdp_attention: elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa" self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa" self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention: elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager" self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager" self.model_config._attn_implementation = "eager"

View File

@@ -96,6 +96,7 @@ class PatchManager:
# self._apply_flex_attention_patches() # self._apply_flex_attention_patches()
self._apply_flash_attention_patches() self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch() self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches() self._apply_fsdp_patches()
self._apply_adapter_patches() self._apply_adapter_patches()
self._apply_model_specific_patches() self._apply_model_specific_patches()
@@ -201,6 +202,13 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs) patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self): def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures.""" """Apply patches specific to model architectures."""
if ( if (

View File

@@ -204,6 +204,13 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
check_model_config(cfg, model_config) check_model_config(cfg, model_config)
# Extract text config from composite config when explicitly requested
# (set by plugins like Gemma3TextFromMultimodalPlugin)
if getattr(cfg, "extract_text_config", False) and hasattr(
model_config, "get_text_config"
):
model_config = model_config.get_text_config()
return model_config return model_config

View File

@@ -0,0 +1,211 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise ValueError( raise ValueError(
f"Could not import attention class for model_type: {model_type}. " f"Axolotl could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e

View File

@@ -485,6 +485,58 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -501,10 +553,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm, "image_resize_algorithm": image_resize_algorithm,
} }
if chat_template_type in [None, "tokenizer_default"] and hasattr( if chat_template_type in [None, "tokenizer_default"]:
processor.tokenizer, "chat_template" tokenizer = getattr(processor, "tokenizer", processor)
): if hasattr(tokenizer, "chat_template"):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template processing_kwargs["chat_template"] = tokenizer.chat_template
if chat_template_type == "qwen2_vl": if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy( return Qwen2VLProcessingStrategy(
@@ -533,6 +585,15 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy( return Mistral3ProcessingStrategy(
**processing_kwargs, **processing_kwargs,
) )
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor): if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy( return InternVLProcessingStrategy(

View File

@@ -153,13 +153,27 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict: def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history.""" """Extract last loss, learning_rate, grad_norm, and token metrics from log history."""
if not state.log_history: if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0} return {
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
last_log = state.log_history[-1] last_log = state.log_history[-1]
return { return {
"loss": last_log.get("loss", 0), "loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0), "learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0), "grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
} }

View File

@@ -155,6 +155,10 @@ def send_errors(func: Callable) -> Callable:
}, },
) )
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise raise
return wrapper return wrapper

View File

@@ -5,7 +5,6 @@ import importlib
import logging import logging
import os import os
import platform import platform
import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -20,21 +19,6 @@ LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com" POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes # NOTE: Need to keep these up to date with any config schema changes
@@ -46,8 +30,8 @@ FIELDS_TO_REDACT = {
"resume_from_checkpoint", "resume_from_checkpoint",
"hub_model_id", "hub_model_id",
} }
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"} PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_", "trackio_", "swanlab_"}
PATH_INDICATORS = {"path", "dir"} PATH_INDICATORS = {"path", "dir", "data_files"}
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
RELEVANT_PACKAGES = { RELEVANT_PACKAGES = {
@@ -183,11 +167,6 @@ class TelemetryManager:
"false", "false",
"true", "true",
): ):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True return True
# Only rank 0 will send telemetry # Only rank 0 will send telemetry

View File

@@ -31,3 +31,10 @@ organizations:
- "mistral-community" - "mistral-community"
- "llava-hf" - "llava-hf"
- "ByteDance-Seed" - "ByteDance-Seed"
- "ACE-Step"
- "openbmb"
- "MiniMaxAI"
- "stepfun-ai"
- "internlm"
- "katanemo"
- "XiaomiMiMo"

View File

@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
tokens = getattr(state, "tokens", None) tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens: if not (tokens and "trainable_tokens" in tokens):
step_time = time.perf_counter() - self.start_time return
num_tokens_per_device = tokens["trainable_tokens"].clone() step_time = time.perf_counter() - self.start_time
# non data parallel groups have duplicated tokens, so we avoid double-counting if step_time <= 0:
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size return
state.last_tokens_per_second = num_tokens_per_device / step_time
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
if torch.distributed.is_initialized():
dp_size = max(
1, torch.distributed.get_world_size() // self.non_data_parallel_size
)
num_tokens = num_tokens / dp_size
state.last_tokens_per_second = num_tokens / step_time
def on_log( def on_log(
self, self,

View File

@@ -218,6 +218,9 @@ class SequenceParallelContextManager:
self.original_seq_len = 0 self.original_seq_len = 0
self.pad_len = 0 self.pad_len = 0
# Track local valid token count for eval loss correction across CP ranks
self._local_valid_tokens: torch.Tensor | None = None
# Create a partially applied version of the apply_sequence_parallelism function # Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial( self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism, apply_sequence_parallelism,
@@ -270,6 +273,18 @@ class SequenceParallelContextManager:
self.apply_sequence_parallelism(updated_kwargs) self.apply_sequence_parallelism(updated_kwargs)
) )
# Track local valid tokens for eval loss correction
if "labels" in updated_kwargs and not self.models[0].training:
self._local_valid_tokens = (
(updated_kwargs["labels"] != -100).sum().float()
)
# Strip num_items_in_batch during eval so the model uses
# reduction='mean', allowing the post-hook weighted all-reduce
# formula (loss * local_valid) to correctly recover the loss sum
updated_kwargs.pop("num_items_in_batch", None)
else:
self._local_valid_tokens = None
return remaining_args, updated_kwargs return remaining_args, updated_kwargs
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
@@ -287,6 +302,44 @@ class SequenceParallelContextManager:
return output return output
# Post-hook to correct eval loss via weighted all-reduce across CP ranks
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
if self._local_valid_tokens is None:
return output
if not hasattr(output, "loss") or output.loss is None:
return output
local_valid = self._local_valid_tokens.to(output.loss.device)
loss = output.loss.detach().clone()
# Handle rank with zero valid tokens (loss is NaN)
if local_valid.item() == 0:
weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)
else:
weighted_loss = loss * local_valid
total_valid = local_valid.clone()
dist.all_reduce(
weighted_loss,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
dist.all_reduce(
total_valid,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
if total_valid.item() > 0:
output["loss"] = (weighted_loss / total_valid).squeeze()
else:
output["loss"] = torch.tensor(
float("nan"), device=loss.device, dtype=loss.dtype
)
self._local_valid_tokens = None
return output
# Register hooks # Register hooks
for model in self.models: for model in self.models:
self.hook_handles.append( self.hook_handles.append(
@@ -298,6 +351,10 @@ class SequenceParallelContextManager:
self.hook_handles.append( self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook) model.register_forward_hook(sequence_parallel_post_hook)
) )
# Always register eval loss correction hook
self.hook_handles.append(
model.register_forward_hook(eval_loss_correction_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor.""" """Gather sharded outputs from all ranks and reconstruct the full tensor."""

View File

@@ -2,11 +2,19 @@
import os import os
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_default_process_count(): def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"): if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc) return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
LOG.warning(
"AXOLOTL_DATASET_PROCESSES and `dataset_processes` are deprecated and will be "
"removed in a future version. Please use `dataset_num_proc` instead."
)
return int(axolotl_dataset_processes) return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
return int(runpod_cpu_count) return int(runpod_cpu_count)

View File

@@ -86,15 +86,15 @@ class HFMistralTokenizer(MistralCommonBackend):
add_generation_prompt: bool = False, add_generation_prompt: bool = False,
**kwargs, **kwargs,
) -> str | list[int]: ) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg""" """Patched fn to handle setting test mode, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral # pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None) kwargs.pop("real_last_index", None)
kwargs.pop("add_special_tokens", None)
try: try:
if add_generation_prompt: if add_generation_prompt:
self._set_mode(ValidationMode.serving) self._set_mode(ValidationMode.test)
kwargs["continue_final_message"] = True
out = super().apply_chat_template(conversation, **kwargs) out = super().apply_chat_template(conversation, **kwargs)

View File

@@ -609,6 +609,12 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "Whether to use bettertransformers"}, json_schema_extra={"description": "Whether to use bettertransformers"},
) )
sage_attention: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
},
)
eager_attention: bool | None = None eager_attention: bool | None = None
@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"):
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to valdiate GPU capabilities with the configured options""" """Wrapper to valdiate GPU capabilities with the configured options"""
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data return data
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
if (
data.get("sage_attention")
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
):
raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_multigpu_unsloth(cls, data): def check_multigpu_unsloth(cls, data):
@@ -1229,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
): ):
return data return data
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0: if data.get("lora_dropout") != 0:
return data return data

View File

@@ -120,6 +120,12 @@ class ModelOutputConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "how to push checkpoints to hub"}, json_schema_extra={"description": "how to push checkpoints to hub"},
) )
hub_revision: str | None = Field(
default=None,
json_schema_extra={
"description": "branch/revision to push to on hub (default: main)"
},
)
save_safetensors: bool | None = Field( save_safetensors: bool | None = Field(
default=True, default=True,
json_schema_extra={ json_schema_extra={

View File

@@ -166,9 +166,10 @@ class AttentionValidationMixin:
fields = ( fields = (
"xformers_attention", "xformers_attention",
"sdp_attention", "sdp_attention",
"s2_attention", # "s2_attention", # requires both FA and this to be enabled
"flash_attention", "flash_attention",
"flex_attention", "flex_attention",
"sage_attention",
) )
non_empty_count = sum(1 for field in fields if data.get(field)) non_empty_count = sum(1 for field in fields if data.get(field))
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
and not data.get("sdp_attention") and not data.get("sdp_attention")
and not data.get("flex_attention") and not data.get("flex_attention")
and not data.get("xformers_attention") and not data.get("xformers_attention")
and not data.get("sage_attention")
): ):
LOG.warning( LOG.warning(
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
) )
return data return data
@@ -688,6 +690,21 @@ class LoRAValidationMixin:
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("trust_remote_code"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with trust_remote_code. Please disable trust_remote_code "
"or explicitly set lora_*_kernel to false."
)
return data
class RLValidationMixin: class RLValidationMixin:
"""Validation methods related to RL training configuration.""" """Validation methods related to RL training configuration."""

View File

@@ -247,7 +247,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=F
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"] drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3", "gemma3_text"]
if drop_attn_mask: if drop_attn_mask:
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")

View File

@@ -79,7 +79,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800, "ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25, "ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False, "ddp_broadcast_buffers": False,
"dataset_processes": 4, "dataset_num_proc": 4,
} }
) )

View File

@@ -30,7 +30,7 @@ class TestStreamingDatasets:
"sample_packing": sample_packing, "sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing, "pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000, "streaming_multipack_buffer_size": 10000,
"dataset_processes": 1, "dataset_num_proc": 1,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },

View File

@@ -118,20 +118,6 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
assert not manager.enabled assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
assert any(
"Telemetry is now enabled by default" in str(call)
for call in mock_warning.call_args_list
)
def test_is_whitelisted(telemetry_manager_class, mock_whitelist): def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality""" """Test org whitelist functionality"""
with ( with (

View File

@@ -90,3 +90,62 @@ class TestLoRAConfigValidation:
} }
) )
validate_config(invalid_config) validate_config(invalid_config)
@pytest.mark.parametrize(
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
"""Test that lora kernels are incompatible with trust_remote_code"""
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
invalid_config = DictDefault(
{
"adapter": "lora",
kernel_field: True,
"trust_remote_code": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_lora_kernels_trust_remote_code_false(self):
"""Test that lora kernels work when trust_remote_code is false"""
# Test with trust_remote_code=False, lora kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"trust_remote_code": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_mlp_kernel"] is True
assert result["lora_qkv_kernel"] is True
assert result["lora_o_kernel"] is True
# Test with trust_remote_code=None (unset), kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_qkv_kernel": True,
"trust_remote_code": None,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None