Compare commits

..

20 Commits

Author SHA1 Message Date
NanoCode012
970b2a6f2f feat: test for config validation and BC for new peft weight dtype 2026-02-16 21:26:04 +07:00
NanoCode012
1f7f5e7c26 feat: handle lora kernels compat with torchao 2026-02-16 21:25:50 +07:00
NanoCode012
60c0a828cc feat: add torchao's int4, nf4, int8 2026-02-16 21:25:24 +07:00
NanoCode012
4f1b5ad29f fix: clarify how to use lm_eval plugin (#3404) [skip ci] 2026-02-15 07:52:30 -05:00
NanoCode012
d6a2532dd7 feat(doc): clarify how to use scattermoe (#3408) [skip ci]
* feat(doc): clarify how to use scattermoe

* chore: fix wording
2026-02-15 07:51:28 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
Wing Lian
236dad3bb7 set 0.15.0.dev0 version (#3380) 2026-01-30 21:28:01 -05:00
50 changed files with 1471 additions and 105 deletions

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,7 +39,6 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -349,8 +348,6 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -409,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}

View File

@@ -1 +1 @@
0.14.0
0.15.0.dev0

View File

@@ -251,7 +251,6 @@ website:
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
@@ -320,6 +320,7 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:

140
docs/attention.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -210,6 +210,8 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -218,7 +220,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
### delinearize-llama4

View File

@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,6 +19,7 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
]
},
{

44
examples/glm46v/README.md Normal file
View File

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
)

View File

@@ -409,6 +409,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:

View File

@@ -719,6 +719,13 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()

View File

@@ -126,9 +126,6 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
@@ -154,6 +151,8 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs
@@ -164,7 +163,12 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
```
## Usage
@@ -54,8 +54,8 @@ plugins:
- gpt_oss
- granite
- granitemoe
- granitemoeshared
- granitemoehybrid
- granitemoeshared
- hunyuan_v1_dense
- hunyuan_v1_moe
- internvl
@@ -80,16 +80,17 @@ plugins:
- phi3
- phi4_multimodal
- qwen2
- qwen2_vl
- qwen2_moe
- qwen2_vl
- qwen2_5_vl
- qwen3
- qwen3_moe
- qwen3_next
- qwen3_vl
- qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss
- smollm3
- step3p5
- voxtral
## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type: str,
model_type_to_patch: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
if model_type_to_patch not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
"Setting up generic cce patch for model type: %s", model_type_to_patch
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -0,0 +1,44 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -6,6 +6,12 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -16,9 +22,50 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,6 +13,21 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -108,7 +123,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
model=get_model_path(cfg),
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -15,7 +15,7 @@ from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .quantize import dequantize_weight
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
@@ -46,6 +46,12 @@ def get_lora_parameters(
W = base_layer.weight
b = base_layer.bias
# Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
# attribute access to the underlying tensor subclass, so torchao methods like
# .dequantize() or .get_original_weight() would not be visible.
if isinstance(W, DTensor):
W = W.full_tensor()
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None
@@ -86,6 +92,7 @@ def matmul_lora(
B: torch.Tensor | None,
s: float | None,
out: torch.Tensor | None = None,
transpose: bool = True,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
@@ -98,12 +105,15 @@ def matmul_lora(
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
transpose: If True (default), transpose W before matmul (forward path).
Set to False for backward paths where W is already in the correct layout.
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
is_quantized = W_quant is not None or type(W) is not torch.Tensor
W = dequantize_weight(W, W_quant, transpose=transpose)
reshape = False
if X.dim() == 3:
@@ -112,7 +122,7 @@ def matmul_lora(
reshape = True
out = torch.matmul(X, W, out=out)
if W_quant is not None:
if is_quantized:
del W
if A is not None:
@@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function):
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
# Down projection (backward: no transpose needed, W is already [out, in])
grad_down = matmul_lora(
grad_output,
down_weight.t(),
down_weight,
None,
down_quant,
down_B,
down_A,
down_scale,
transpose=False,
)
# Activation backward
@@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function):
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
@@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight, gate_quant)
gate_weight = dequantize_weight(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
del gate_weight
@@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function):
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
q_weight_t = dequantize_weight(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
@@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
k_weight_t = dequantize_weight(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
@@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
v_weight_t = dequantize_weight(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
@@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function):
d_B = s * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
W = dequantize_weight(W, W_quant, transpose=True)
dX = dY @ W.t()
del W

View File

@@ -146,3 +146,43 @@ def dequantize(
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out
def dequantize_weight(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
transpose: bool = False,
) -> torch.Tensor:
"""Unified dequantization for both torchao and bnb quantized weights.
For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes
using the appropriate instance method. For bnb Params4bit, delegates to the
optimized CUDA kernel in ``dequantize``.
Args:
W: Quantized weight tensor ``[out_features, in_features]``.
quant_state: bnb ``QuantState`` (None for torchao / unquantized).
transpose: If True, return ``[in_features, out_features]``.
Returns:
Dequantized float tensor, optionally transposed.
"""
# torchao path: tensor subclass with embedded quantization state
if quant_state is None and type(W) is not torch.Tensor:
result = None
# NF4Tensor (check first — NF4Tensor.dequantize is a static method)
if hasattr(W, "get_original_weight"):
result = W.get_original_weight()
else:
# AffineQuantizedTensor (INT4, etc.)
try:
result = W.dequantize()
except (TypeError, RuntimeError):
pass
if result is not None:
return result.t() if transpose else result
# bnb path: transpose input before the CUDA kernel (existing convention)
if transpose:
return dequantize(W.t(), quant_state)
return dequantize(W, quant_state)

View File

@@ -23,6 +23,7 @@ from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import TorchAOQuantDType
LOG = get_logger(__name__)
@@ -134,11 +135,13 @@ def load_lora(
rank = int(os.environ.get("LOCAL_RANK", 0))
is_torchao = cfg.peft and cfg.peft.backend == "torchao"
if (
cfg.fsdp_config
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_meta_for_peft(model)
@@ -146,6 +149,15 @@ def load_lora(
if cfg.peft_autocast_adapter_dtype is not None:
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
# Patch PEFT's torchao dispatch before any model creation/loading.
# Must happen before both get_peft_model and PeftModel.from_pretrained,
# as both trigger LoRA layer dispatch that would fail for INT4/NF4 weights.
# INT8 is natively supported by PEFT's TorchaoLoraLinear, so skip the patch.
if is_torchao and cfg.peft.weight_dtype != TorchAOQuantDType.int8:
from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch
patch_peft_torchao_dispatch()
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
if cfg.lora_on_cpu:
@@ -172,6 +184,7 @@ def load_lora(
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_peft_meta_for_training(model)

View File

@@ -158,6 +158,15 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@property
def is_torchao_qlora(self):
"""Property that determines if torchao backend is used for QLoRA."""
return (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
)
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -338,7 +347,12 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
(
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled
)
or (
@@ -486,8 +500,9 @@ class ModelLoader:
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
if self.is_fsdp_enabled:
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
if self.is_qlora_and_fsdp_enabled:
# For QLoRA + FSDP with bnb, we still need to set device_map for proper initialization
# torchao tensors work natively with FSDP2, no device_map override needed
if self.is_qlora_and_fsdp_enabled and not self.is_torchao_qlora:
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
@@ -556,6 +571,44 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
and not self.cfg.merge_lora
):
from transformers import TorchAoConfig
from axolotl.utils.schemas.enums import TorchAOQuantDType
weight_dtype = self.cfg.peft.weight_dtype
if weight_dtype == TorchAOQuantDType.int4:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int4_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.int8:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int8_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.nf4:
from torchao.dtypes._nf4tensor_api import NF4WeightOnlyConfig
block_size = self.cfg.peft.group_size or 64
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type=NF4WeightOnlyConfig(
block_size=block_size,
scaler_block_size=256,
),
)
else:
raise ValueError(
f"Unsupported torchao weight_dtype for QLoRA: {weight_dtype}. "
"Supported: int4, int8, nf4"
)
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
@@ -612,6 +665,10 @@ class ModelLoader:
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
@@ -851,6 +908,10 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
# torchao quantized models don't use Params4bit and don't need kbit preparation
if self.is_torchao_qlora:
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -96,6 +96,7 @@ class PatchManager:
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -201,6 +202,13 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -340,10 +348,12 @@ class PatchManager:
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
is_torchao = self.cfg.peft and self.cfg.peft.backend == "torchao"
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and not is_torchao
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,

View File

@@ -0,0 +1,211 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Axolotl could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Error: {str(e)}"
) from e

View File

@@ -78,3 +78,30 @@ def patch_peft_prep_code():
axolotl.loaders.model.prepare_model_for_kbit_training = (
fixed_prepare_model_for_kbit_training
)
def patch_peft_torchao_dispatch():
"""Skip PEFT's TorchaoLoraLinear for non-INT8 torchao weights.
PEFT's dispatch_torchao() matches AffineQuantizedTensor but then errors in
_check_dtype_supported() because it only allows INT8. Our LoRA kernels handle
dequantization explicitly, so we bypass PEFT's torchao dispatch entirely and
let it fall back to standard Linear LoRA layers.
"""
try:
from peft.tuners.lora import torchao as peft_torchao
except ImportError:
LOG.warning("Could not import peft.tuners.lora.torchao for patching")
return
if getattr(peft_torchao, "_axolotl_patched", False):
return
def patched_dispatch(target, adapter_name, lora_config, **kwargs):
# Return None so PEFT falls back to standard Linear LoRA layers.
# Our LoRA kernels handle torchao dequantization explicitly.
return None
peft_torchao.dispatch_torchao = patched_dispatch
peft_torchao._axolotl_patched = True
LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear")

View File

@@ -485,6 +485,58 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -501,10 +553,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type in [None, "tokenizer_default"]:
tokenizer = getattr(processor, "tokenizer", processor)
if hasattr(tokenizer, "chat_template"):
processing_kwargs["chat_template"] = tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
@@ -533,6 +585,15 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy(
**processing_kwargs,
)
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy(

View File

@@ -153,13 +153,27 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history."""
"""Extract last loss, learning_rate, grad_norm, and token metrics from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
return {
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
}

View File

@@ -155,6 +155,10 @@ def send_errors(func: Callable) -> Callable:
},
)
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise
return wrapper

View File

@@ -5,7 +5,6 @@ import importlib
import logging
import os
import platform
import time
import uuid
from pathlib import Path
from typing import Any
@@ -20,21 +19,6 @@ LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes
@@ -46,8 +30,8 @@ FIELDS_TO_REDACT = {
"resume_from_checkpoint",
"hub_model_id",
}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir"}
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_", "trackio_", "swanlab_"}
PATH_INDICATORS = {"path", "dir", "data_files"}
# pylint: disable=duplicate-code
RELEVANT_PACKAGES = {
@@ -183,11 +167,6 @@ class TelemetryManager:
"false",
"true",
):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True
# Only rank 0 will send telemetry

View File

@@ -31,3 +31,10 @@ organizations:
- "mistral-community"
- "llava-hf"
- "ByteDance-Seed"
- "ACE-Step"
- "openbmb"
- "MiniMaxAI"
- "stepfun-ai"
- "internlm"
- "katanemo"
- "XiaomiMiMo"

View File

@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs,
): # pylint: disable=unused-argument
tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens:
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = tokens["trainable_tokens"].clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
if not (tokens and "trainable_tokens" in tokens):
return
step_time = time.perf_counter() - self.start_time
if step_time <= 0:
return
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
if torch.distributed.is_initialized():
dp_size = max(
1, torch.distributed.get_world_size() // self.non_data_parallel_size
)
num_tokens = num_tokens / dp_size
state.last_tokens_per_second = num_tokens / step_time
def on_log(
self,

View File

@@ -218,6 +218,9 @@ class SequenceParallelContextManager:
self.original_seq_len = 0
self.pad_len = 0
# Track local valid token count for eval loss correction across CP ranks
self._local_valid_tokens: torch.Tensor | None = None
# Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
@@ -270,6 +273,18 @@ class SequenceParallelContextManager:
self.apply_sequence_parallelism(updated_kwargs)
)
# Track local valid tokens for eval loss correction
if "labels" in updated_kwargs and not self.models[0].training:
self._local_valid_tokens = (
(updated_kwargs["labels"] != -100).sum().float()
)
# Strip num_items_in_batch during eval so the model uses
# reduction='mean', allowing the post-hook weighted all-reduce
# formula (loss * local_valid) to correctly recover the loss sum
updated_kwargs.pop("num_items_in_batch", None)
else:
self._local_valid_tokens = None
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
@@ -287,6 +302,44 @@ class SequenceParallelContextManager:
return output
# Post-hook to correct eval loss via weighted all-reduce across CP ranks
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
if self._local_valid_tokens is None:
return output
if not hasattr(output, "loss") or output.loss is None:
return output
local_valid = self._local_valid_tokens.to(output.loss.device)
loss = output.loss.detach().clone()
# Handle rank with zero valid tokens (loss is NaN)
if local_valid.item() == 0:
weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)
else:
weighted_loss = loss * local_valid
total_valid = local_valid.clone()
dist.all_reduce(
weighted_loss,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
dist.all_reduce(
total_valid,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
if total_valid.item() > 0:
output["loss"] = (weighted_loss / total_valid).squeeze()
else:
output["loss"] = torch.tensor(
float("nan"), device=loss.device, dtype=loss.dtype
)
self._local_valid_tokens = None
return output
# Register hooks
for model in self.models:
self.hook_handles.append(
@@ -298,6 +351,10 @@ class SequenceParallelContextManager:
self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook)
)
# Always register eval loss correction hook
self.hook_handles.append(
model.register_forward_hook(eval_loss_correction_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""

View File

@@ -2,11 +2,19 @@
import os
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
LOG.warning(
"AXOLOTL_DATASET_PROCESSES and `dataset_processes` are deprecated and will be "
"removed in a future version. Please use `dataset_num_proc` instead."
)
return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
return int(runpod_cpu_count)

View File

@@ -86,15 +86,15 @@ class HFMistralTokenizer(MistralCommonBackend):
add_generation_prompt: bool = False,
**kwargs,
) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
"""Patched fn to handle setting test mode, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None)
kwargs.pop("add_special_tokens", None)
try:
if add_generation_prompt:
self._set_mode(ValidationMode.serving)
kwargs["continue_final_message"] = True
self._set_mode(ValidationMode.test)
out = super().apply_chat_template(conversation, **kwargs)

View File

@@ -609,6 +609,12 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "Whether to use bettertransformers"},
)
sage_attention: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
},
)
eager_attention: bool | None = None
@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"):
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to valdiate GPU capabilities with the configured options"""
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
if (
data.get("sage_attention")
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
):
raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation."
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
@@ -1229,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
return data
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data

View File

@@ -8,6 +8,7 @@ import torch
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
nf4 = "nf4"
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
@@ -16,6 +17,8 @@ class TorchAOQuantDType(Enum):
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str == "nf4":
return TorchAOQuantDType.nf4
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":

View File

@@ -120,6 +120,12 @@ class ModelOutputConfig(BaseModel):
default=None,
json_schema_extra={"description": "how to push checkpoints to hub"},
)
hub_revision: str | None = Field(
default=None,
json_schema_extra={
"description": "branch/revision to push to on hub (default: main)"
},
)
save_safetensors: bool | None = Field(
default=True,
json_schema_extra={

View File

@@ -1,9 +1,12 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import validate_ao_dtype
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -15,7 +18,7 @@ class LoftQConfig(BaseModel):
class PeftConfig(BaseModel):
"""peftq configuration subset"""
"""PEFT configuration subset"""
loftq_config: LoftQConfig | None = Field(
default=None,
@@ -23,6 +26,29 @@ class PeftConfig(BaseModel):
"description": "Configuration options for loftq initialization for LoRA"
},
)
backend: Literal["bnb", "torchao"] | None = Field(
default=None,
json_schema_extra={
"description": "Quantization backend for QLoRA. 'bnb' for bitsandbytes (default), 'torchao' for torchao."
},
)
weight_dtype: TorchAOQuantDType | None = Field(
default=None,
json_schema_extra={
"description": "Weight quantization dtype (int4, int8, or nf4). Also used with bnb backend to auto-configure quantization."
},
)
group_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Group size for quantization. Defaults to 128 for int4, 64 for nf4."
},
)
@field_validator("weight_dtype", mode="before")
@classmethod
def validate_weight_dtype(cls, v):
return validate_ao_dtype(v)
class LoraConfig(BaseModel):
@@ -156,6 +182,56 @@ class LoraConfig(BaseModel):
merge_lora: bool | None = None
@model_validator(mode="before")
@classmethod
def auto_detect_qlora(cls, data):
"""Auto-set adapter type and quantization flags from peft config.
When peft.backend and peft.weight_dtype are set, this infers the correct
adapter type and internal flags (load_in_4bit, load_in_8bit) so users
don't need to set them manually.
"""
peft = data.get("peft")
if not isinstance(peft, dict):
return data
backend = peft.get("backend")
weight_dtype = peft.get("weight_dtype")
# Validate: weight_dtype requires backend
if weight_dtype and not backend:
raise ValueError(
"peft.backend is required when peft.weight_dtype is set. "
"Use 'torchao' or 'bnb'."
)
if not weight_dtype:
return data
adapter = data.get("adapter")
if backend == "torchao":
# torchao: any quantized weight_dtype means qlora
if adapter == "lora":
data["adapter"] = "qlora"
elif backend == "bnb":
if weight_dtype == "nf4":
# bnb nf4 = qlora with load_in_4bit
if adapter == "lora":
data["adapter"] = "qlora"
data.setdefault("load_in_4bit", True)
elif weight_dtype == "int8":
# bnb int8 = lora with load_in_8bit
data.setdefault("load_in_8bit", True)
else:
raise ValueError(
f"peft.weight_dtype '{weight_dtype}' is not supported with bnb backend. "
"Supported: nf4, int8."
)
return data
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
@@ -173,6 +249,8 @@ class LoraConfig(BaseModel):
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
is_torchao = self.peft and self.peft.backend == "torchao"
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
@@ -184,7 +262,20 @@ class LoraConfig(BaseModel):
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
elif is_torchao:
# torchao backend: validate torchao-specific requirements
if self.load_in_4bit or self.load_in_8bit:
raise ValueError(
"load_in_4bit/load_in_8bit are for bitsandbytes. "
"With peft.backend: torchao, quantization is handled by torchao."
)
if not self.peft.weight_dtype:
raise ValueError(
"peft.weight_dtype is required when peft.backend is 'torchao'"
)
else:
# Default bnb path
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")

View File

@@ -16,6 +16,8 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v == "nf4":
return TorchAOQuantDType.nf4
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":

View File

@@ -166,9 +166,10 @@ class AttentionValidationMixin:
fields = (
"xformers_attention",
"sdp_attention",
"s2_attention",
# "s2_attention", # requires both FA and this to be enabled
"flash_attention",
"flex_attention",
"sage_attention",
)
non_empty_count = sum(1 for field in fields if data.get(field))
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
and not data.get("sdp_attention")
and not data.get("flex_attention")
and not data.get("xformers_attention")
and not data.get("sage_attention")
):
LOG.warning(
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
)
return data
@@ -688,6 +690,21 @@ class LoRAValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("trust_remote_code"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with trust_remote_code. Please disable trust_remote_code "
"or explicitly set lora_*_kernel to false."
)
return data
class RLValidationMixin:
"""Validation methods related to RL training configuration."""

View File

@@ -79,7 +79,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
"dataset_processes": 4,
"dataset_num_proc": 4,
}
)

View File

@@ -3,7 +3,7 @@
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
from axolotl.kernels.quantize import dequantize, dequantize_weight
def test_dequantize_null_state():
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
result = dequantize(W, quant_state, out=out)
assert result is out
def test_dequantize_weight_plain_tensor():
"""Test that dequantize_weight passes through unquantized tensors unchanged"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=False)
assert torch.equal(result, W)
def test_dequantize_weight_plain_tensor_transpose():
"""Test that dequantize_weight transposes unquantized tensors"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=True)
assert result.shape == (64, 32)
assert torch.equal(result, W.t())

View File

@@ -30,7 +30,7 @@ class TestStreamingDatasets:
"sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000,
"dataset_processes": 1,
"dataset_num_proc": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -118,20 +118,6 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
assert any(
"Telemetry is now enabled by default" in str(call)
for call in mock_warning.call_args_list
)
def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality"""
with (

View File

@@ -3,6 +3,14 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
BASE_CFG = {
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
class TestLoRAConfigValidation:
"""Test suite for LoRA/QLoRA configuration validation"""
@@ -90,3 +98,254 @@ class TestLoRAConfigValidation:
}
)
validate_config(invalid_config)
@pytest.mark.parametrize(
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
"""Test that lora kernels are incompatible with trust_remote_code"""
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
invalid_config = DictDefault(
{
"adapter": "lora",
kernel_field: True,
"trust_remote_code": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_lora_kernels_trust_remote_code_false(self):
"""Test that lora kernels work when trust_remote_code is false"""
# Test with trust_remote_code=False, lora kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"trust_remote_code": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_mlp_kernel"] is True
assert result["lora_qkv_kernel"] is True
assert result["lora_o_kernel"] is True
# Test with trust_remote_code=None (unset), kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_qkv_kernel": True,
"trust_remote_code": None,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None
class TestTorchaoQLoRAConfigValidation:
"""Test suite for torchao QLoRA auto-detection and validation"""
# --- Auto-detection: torchao ---
@pytest.mark.parametrize("weight_dtype", ["int4", "int8", "nf4"])
def test_torchao_auto_detect_from_lora(self, weight_dtype):
"""adapter: lora + peft.backend: torchao auto-upgrades to qlora"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "torchao", "weight_dtype": weight_dtype},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["peft"]["backend"] == "torchao"
def test_torchao_explicit_qlora(self):
"""adapter: qlora + peft.backend: torchao works directly"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
# --- Auto-detection: bnb ---
def test_bnb_nf4_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: nf4 → qlora + load_in_4bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_bnb_int8_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: int8 → lora + load_in_8bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int8"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_bnb_nf4_explicit_qlora_auto_sets_load_in_4bit(self):
"""adapter: qlora + peft.backend: bnb + weight_dtype: nf4 auto-sets load_in_4bit"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
# --- Backward compat ---
def test_old_style_qlora_unchanged(self):
"""Old-style adapter: qlora + load_in_4bit: true still works"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_old_style_lora_8bit_unchanged(self):
"""Old-style adapter: lora + load_in_8bit: true still works"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_8bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_plain_lora_unchanged(self):
"""adapter: lora without peft block stays as lora"""
cfg = DictDefault(
{
"adapter": "lora",
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
# --- Validation errors ---
def test_torchao_with_load_in_4bit_errors(self):
"""peft.backend: torchao + load_in_4bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_with_load_in_8bit_errors(self):
"""peft.backend: torchao + load_in_8bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_8bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_without_weight_dtype_errors(self):
"""peft.backend: torchao without weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.weight_dtype is required"):
validate_config(cfg)
def test_weight_dtype_without_backend_errors(self):
"""peft.weight_dtype without peft.backend errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.backend is required"):
validate_config(cfg)
def test_bnb_unsupported_weight_dtype_errors(self):
"""peft.backend: bnb + unsupported weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="not supported with bnb"):
validate_config(cfg)
# --- Redundant flags don't conflict ---
def test_bnb_nf4_with_explicit_load_in_4bit(self):
"""peft.backend: bnb + weight_dtype: nf4 + load_in_4bit: true is fine (redundant)"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_4bit": True,
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True