Compare commits

..

17 Commits

Author SHA1 Message Date
Wing Lian
d260eeb57d match protected method 2026-02-15 07:55:55 -05:00
Wing Lian
5a7f007d20 cleanup ao fp8 patching 2026-02-13 17:02:23 -05:00
Wing Lian
5eb265513c fix generic patch for cce (#3405) 2026-02-12 08:58:04 -05:00
NanoCode012
06ac407b92 feat: improve telemetry log (#3398)
* fix: redact trackio and data_files

* fix: add new orgs to whitelist

* feat: add run id to logs for users to easily share

* fix: update to add more metrics

* fix: add missed experiment tracker

* chore: formatting in main
2026-02-10 23:01:34 +07:00
NanoCode012
4e22cf0651 fix: remove telemetry warning (#3397) [skip ci] 2026-02-10 23:01:16 +07:00
VED
a4ee56c315 fix: set rollout in GRPO training_kwargs (#3392) 2026-02-10 18:06:15 +07:00
NanoCode012
c67cbcb0f5 fix: ignore add_special_tokens and use test mode for generation for mistral tokenizer (#3396) [skip ci]
* fix: ignore add_special_tokens and use test mode for generation

* fix: incorrectly setting kwarg
2026-02-10 18:03:26 +07:00
NanoCode012
a2da852576 fix: improve lora kernels failure message and handle trust_remote_code (#3378) [skip ci]
* fix: improve lora kernels failure message and handle trust_remote_code

* chore: re-order model guides
2026-02-10 17:58:40 +07:00
madScientist10
37e9da7a53 add hub_revision support for specifying branch when pushing checkpoints (#3387) [skip ci] 2026-02-10 17:53:09 +07:00
NanoCode012
ed7105dba7 fix: GRPO config not accept max_prompt_length (#3390) [skip ci] 2026-02-10 17:52:09 +07:00
NanoCode012
b6d3653f74 feat: add step3p5 for cce (#3384) [skip ci]
* feat: add step3p5 for cce

* chore: reorder model
2026-02-10 17:51:43 +07:00
NanoCode012
fcc4cfdb63 feat: add sageattention (#2823) [skip ci]
* feat: add sageattention

* feat: call path on pre model load

* fix: patch to use register to correct var

* fix: add strict check import at start

* chore: fix comments

* chore: refactor

* feat: add capability check

* fix: missed underscore

* fix: let sageattention use FA backend in transformers

* feat: update sage attention for attention mask and position ids

* feat: allow sample packing but add warning without packing

* fix: loss hitting 0 with packing and attention mask note

* feat: downcast embeds if sage attention too

* feat: add config validation

* feat: add attention docs

* chore: docs
2026-02-10 17:49:21 +07:00
VED
97a4f28511 fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none

* patch calculating  eval loss for cp
2026-02-10 17:47:26 +07:00
VED
86a5803212 train_per_sec_per_gpu metric (#3364) [skip ci]
* fix token count

* guard for none n zero
2026-02-10 17:44:55 +07:00
tgoab
530a0c0bf0 Changes from dataset_processes to dataset_num_proc (#3352) [skip ci]
* changes from dataset_processes to dataset_num_proc

* deprecation message improved

---------

Co-authored-by: Juliana Nieto Cárdenas <jnietoca@purdue.edu>
2026-02-10 17:44:17 +07:00
VED
0343a72cc9 add glm support + patch (#3329) [skip ci]
* add glm support + patch

* lint

* lint

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm4/glm-4-6v-flash-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update src/axolotl/processing_strategies.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patch removed

* lint

* lint2

* docs + rename

* rmv moe

* docs

* removed processor

* sdpa T_T"

* ddp_find_unused_parameters: true

* muti gpu yaml tested both

* muti gpu yaml tested both

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/glm46v/README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* rmv text only section + v5 comments

* rename

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2026-02-10 17:43:53 +07:00
Wing Lian
236dad3bb7 set 0.15.0.dev0 version (#3380) 2026-01-30 21:28:01 -05:00
39 changed files with 916 additions and 189 deletions

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- | | --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset | | `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub | | `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes | | `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory | | `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets | | `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging | | `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,7 +39,6 @@
# type: # linear | dynamic # type: # linear | dynamic
# factor: # float # factor: # float
# # Whether you are training a 4-bit GPTQ quantized model # # Whether you are training a 4-bit GPTQ quantized model
# gptq: true # gptq: true
# gptq_groupsize: 128 # group size # gptq_groupsize: 128 # group size
@@ -107,7 +106,7 @@
# push_dataset_to_hub: # repo path # push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set. # # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set # dataset_num_proc: # defaults to os.cpu_count() if not set
# # push checkpoints to hub # # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model # hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub # # how to push checkpoints to hub
@@ -349,8 +348,6 @@
# # Allow overwrite yml config using from cli # # Allow overwrite yml config using from cli
# strict: # strict:
base_model: ${BASE_MODEL} base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS} base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG} base_model_config: ${BASE_MODEL_CONFIG}
@@ -409,7 +406,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE} default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH} dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB} push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES} dataset_num_proc: ${DATASET_NUM_PROC}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY} dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID} hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY} hub_strategy: ${HUB_STRATEGY}

View File

@@ -1 +1 @@
0.14.0 0.15.0.dev0

View File

@@ -251,7 +251,6 @@ website:
- docs/models/olmo3.qmd - docs/models/olmo3.qmd
- docs/models/trinity.qmd - docs/models/trinity.qmd
- docs/models/arcee.qmd - docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3" - section: "Ministral3"
contents: contents:
- docs/models/ministral3.qmd - docs/models/ministral3.qmd
@@ -266,6 +265,7 @@ website:
- docs/models/mistral-small.qmd - docs/models/mistral-small.qmd
- docs/models/voxtral.qmd - docs/models/voxtral.qmd
- docs/models/devstral.qmd - docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd - docs/models/llama-4.qmd
- docs/models/llama-2.qmd - docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd - docs/models/qwen3-next.qmd
@@ -320,6 +320,7 @@ website:
- docs/multipack.qmd - docs/multipack.qmd
- docs/mixed_precision.qmd - docs/mixed_precision.qmd
- docs/optimizers.qmd - docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features" - section: "Advanced Features"
contents: contents:

140
docs/attention.qmd Normal file
View File

@@ -0,0 +1,140 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -89,6 +89,10 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT. Currently, LoRA kernels are not supported for RLHF training, only SFT.
::: :::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements ## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels) - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,6 +19,7 @@ format:
- [Gemma-3n](#sec-gemma-3n) - [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2) - [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl) - [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl) - [Intern-VL](#sec-intern-vl)
@@ -183,6 +184,18 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2} ### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip} ::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
] ]
}, },
{ {

44
examples/glm46v/README.md Normal file
View File

@@ -0,0 +1,44 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,53 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -0,0 +1,50 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
) )

View File

@@ -258,6 +258,11 @@ class TrainerBuilderBase(abc.ABC):
bf16 = bf16 if bf16 is not None else False bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16 training_args_kwargs["bf16"] = bf16
if self.cfg.fp8:
training_args_kwargs["fp8"] = True
if self.cfg.fp8_enable_fsdp_float8_all_gather:
training_args_kwargs["enable_fsdp_float8_all_gather:"] = True
def _configure_scheduler(self, training_args_kwargs: dict): def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]: if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["lr_scheduler_type"] = "cosine"
@@ -409,6 +414,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy: if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict): def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps # save_strategy and save_steps
if self.cfg.save_steps: if self.cfg.save_steps:

View File

@@ -584,11 +584,9 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
def additional_accelerator_args( def build_fp8_accelerator_args(self) -> dict[str, Any]:
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs args = {}
) -> dict[str, Any]: if self.args.fp8:
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig from torchao.float8 import Float8LinearConfig
@@ -596,15 +594,22 @@ class AxolotlTrainer(
# scaling strategy. See more details here: # scaling strategy. See more details here:
# https://github.com/pytorch/ao/tree/main/torchao/float8. # https://github.com/pytorch/ao/tree/main/torchao/float8.
config = Float8LinearConfig( config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, enable_fsdp_float8_all_gather=self.args.enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True, force_recompute_fp8_weight_in_bwd=self.args.enable_fsdp_float8_all_gather
is True,
) )
ret_kwargs["mixed_precision"] = "fp8" args["mixed_precision"] = "fp8"
ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore args["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return ret_kwargs return args
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
args = super().build_accelerator_args(**kwargs)
fp8_args = self.build_fp8_accelerator_args()
args.update(fp8_args)
return args
def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
""" """
@@ -719,6 +724,13 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}") LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = ( supported_classes = (
(PreTrainedModel,) (PreTrainedModel,)
if not is_peft_available() if not is_peft_available()

View File

@@ -126,9 +126,6 @@ class GRPOStrategy:
if trl.use_liger_loss is not None: if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None: if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = ( grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation trl.multi_objective_aggregation
@@ -154,6 +151,8 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = ( trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes cfg.trl.reward_processing_classes
) )
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs return trainer_kwargs
@@ -164,7 +163,12 @@ class GRPOStrategy:
@classmethod @classmethod
def get_blocklist_args_kwargs(cls) -> list[str]: def get_blocklist_args_kwargs(cls) -> list[str]:
return ["dataset_num_proc", "max_length", "include_tokens_per_second"] return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod @classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -263,3 +263,13 @@ class AxolotlTrainingMixins:
dion_rank_multiple_of: int | None = field( dion_rank_multiple_of: int | None = field(
default=None, default=None,
) )
fp8: bool | None = field(
default=None,
metadata={"help": "Whether to use FP8 precision for training"},
)
enable_fsdp_float8_all_gather: bool | None = field(
default=None,
metadata={"help": "Whether to use FSDP with FP8 precision for all_gather"},
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"
``` ```
## Usage ## Usage
@@ -54,8 +54,8 @@ plugins:
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
- granitemoeshared
- granitemoehybrid - granitemoehybrid
- granitemoeshared
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- internvl - internvl
@@ -80,16 +80,17 @@ plugins:
- phi3 - phi3
- phi4_multimodal - phi4_multimodal
- qwen2 - qwen2
- qwen2_vl
- qwen2_moe - qwen2_moe
- qwen2_vl
- qwen2_5_vl - qwen2_5_vl
- qwen3 - qwen3
- qwen3_moe - qwen3_moe
- qwen3_next
- qwen3_vl - qwen3_vl
- qwen3_vl_moe - qwen3_vl_moe
- qwen3_next
- smollm3
- seed_oss - seed_oss
- smollm3
- step3p5
- voxtral - voxtral
## Citation ## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f4b5712"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`'
) )
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like( def patch_llama_like(
self, self,
model_type: str, model_type_to_patch: str,
) -> None: ) -> None:
""" """
Generic patch for model architectures with causal lm similar to llama Generic patch for model architectures with causal lm similar to llama
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic( def patch_generic(
maybe_model, patch_options, model_type: str, remote_model_id: str | None maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
): ):
import cut_cross_entropy.transformers.llama import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward from cut_cross_entropy.transformers.llama import cce_forward
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e
if model_type not in PATCH_FNS: if model_type_to_patch not in PATCH_FNS:
LOG.warning_once( LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type "Setting up generic cce patch for model type: %s", model_type_to_patch
) )
LOG.warning_once( LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected." f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
) )
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -338,7 +338,12 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility. # we need to convert them back to fp16/bf16 for flash-attn compatibility.
( (
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) (
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
) )
or ( or (
@@ -612,6 +617,10 @@ class ModelLoader:
elif self.cfg.sdp_attention: elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa" self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa" self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention: elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager" self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager" self.model_config._attn_implementation = "eager"

View File

@@ -96,10 +96,10 @@ class PatchManager:
# self._apply_flex_attention_patches() # self._apply_flex_attention_patches()
self._apply_flash_attention_patches() self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch() self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches() self._apply_fsdp_patches()
self._apply_adapter_patches() self._apply_adapter_patches()
self._apply_model_specific_patches() self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches() self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches() self._apply_gradient_checkpointing_patches()
self._patch_attention() self._patch_attention()
@@ -201,6 +201,13 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs) patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self): def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures.""" """Apply patches specific to model architectures."""
if ( if (
@@ -227,17 +234,6 @@ class PatchManager:
patch_kimi_model() patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
def _apply_flash_attention_peft_patches(self): def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT.""" """Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter: if self.cfg.adapter:

View File

@@ -0,0 +1,211 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -169,7 +169,8 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise ValueError( raise ValueError(
f"Could not import attention class for model_type: {model_type}. " f"Axolotl could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e

View File

@@ -1,83 +0,0 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
"""
Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = create_code
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
patched_trainer_code = PATCHED_TRAINER_CODE.format(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
)
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals())
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = (
fixed_create_accelerator_and_postprocess
)

View File

@@ -485,6 +485,58 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -501,10 +553,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm, "image_resize_algorithm": image_resize_algorithm,
} }
if chat_template_type in [None, "tokenizer_default"] and hasattr( if chat_template_type in [None, "tokenizer_default"]:
processor.tokenizer, "chat_template" tokenizer = getattr(processor, "tokenizer", processor)
): if hasattr(tokenizer, "chat_template"):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template processing_kwargs["chat_template"] = tokenizer.chat_template
if chat_template_type == "qwen2_vl": if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy( return Qwen2VLProcessingStrategy(
@@ -533,6 +585,15 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy( return Mistral3ProcessingStrategy(
**processing_kwargs, **processing_kwargs,
) )
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor): if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy( return InternVLProcessingStrategy(

View File

@@ -153,13 +153,27 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict: def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, and grad_norm from log history.""" """Extract last loss, learning_rate, grad_norm, and token metrics from log history."""
if not state.log_history: if not state.log_history:
return {"loss": 0, "learning_rate": 0, "grad_norm": 0} return {
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
last_log = state.log_history[-1] last_log = state.log_history[-1]
return { return {
"loss": last_log.get("loss", 0), "loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0), "learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0), "grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
} }

View File

@@ -155,6 +155,10 @@ def send_errors(func: Callable) -> Callable:
}, },
) )
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise raise
return wrapper return wrapper

View File

@@ -5,7 +5,6 @@ import importlib
import logging import logging
import os import os
import platform import platform
import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -20,21 +19,6 @@ LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com" POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes # NOTE: Need to keep these up to date with any config schema changes
@@ -46,8 +30,8 @@ FIELDS_TO_REDACT = {
"resume_from_checkpoint", "resume_from_checkpoint",
"hub_model_id", "hub_model_id",
} }
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"} PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_", "trackio_", "swanlab_"}
PATH_INDICATORS = {"path", "dir"} PATH_INDICATORS = {"path", "dir", "data_files"}
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
RELEVANT_PACKAGES = { RELEVANT_PACKAGES = {
@@ -183,11 +167,6 @@ class TelemetryManager:
"false", "false",
"true", "true",
): ):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True return True
# Only rank 0 will send telemetry # Only rank 0 will send telemetry

View File

@@ -31,3 +31,10 @@ organizations:
- "mistral-community" - "mistral-community"
- "llava-hf" - "llava-hf"
- "ByteDance-Seed" - "ByteDance-Seed"
- "ACE-Step"
- "openbmb"
- "MiniMaxAI"
- "stepfun-ai"
- "internlm"
- "katanemo"
- "XiaomiMiMo"

View File

@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
tokens = getattr(state, "tokens", None) tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens: if not (tokens and "trainable_tokens" in tokens):
step_time = time.perf_counter() - self.start_time return
num_tokens_per_device = tokens["trainable_tokens"].clone() step_time = time.perf_counter() - self.start_time
# non data parallel groups have duplicated tokens, so we avoid double-counting if step_time <= 0:
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size return
state.last_tokens_per_second = num_tokens_per_device / step_time
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
if torch.distributed.is_initialized():
dp_size = max(
1, torch.distributed.get_world_size() // self.non_data_parallel_size
)
num_tokens = num_tokens / dp_size
state.last_tokens_per_second = num_tokens / step_time
def on_log( def on_log(
self, self,

View File

@@ -218,6 +218,9 @@ class SequenceParallelContextManager:
self.original_seq_len = 0 self.original_seq_len = 0
self.pad_len = 0 self.pad_len = 0
# Track local valid token count for eval loss correction across CP ranks
self._local_valid_tokens: torch.Tensor | None = None
# Create a partially applied version of the apply_sequence_parallelism function # Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial( self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism, apply_sequence_parallelism,
@@ -270,6 +273,18 @@ class SequenceParallelContextManager:
self.apply_sequence_parallelism(updated_kwargs) self.apply_sequence_parallelism(updated_kwargs)
) )
# Track local valid tokens for eval loss correction
if "labels" in updated_kwargs and not self.models[0].training:
self._local_valid_tokens = (
(updated_kwargs["labels"] != -100).sum().float()
)
# Strip num_items_in_batch during eval so the model uses
# reduction='mean', allowing the post-hook weighted all-reduce
# formula (loss * local_valid) to correctly recover the loss sum
updated_kwargs.pop("num_items_in_batch", None)
else:
self._local_valid_tokens = None
return remaining_args, updated_kwargs return remaining_args, updated_kwargs
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
@@ -287,6 +302,44 @@ class SequenceParallelContextManager:
return output return output
# Post-hook to correct eval loss via weighted all-reduce across CP ranks
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
if self._local_valid_tokens is None:
return output
if not hasattr(output, "loss") or output.loss is None:
return output
local_valid = self._local_valid_tokens.to(output.loss.device)
loss = output.loss.detach().clone()
# Handle rank with zero valid tokens (loss is NaN)
if local_valid.item() == 0:
weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)
else:
weighted_loss = loss * local_valid
total_valid = local_valid.clone()
dist.all_reduce(
weighted_loss,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
dist.all_reduce(
total_valid,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
if total_valid.item() > 0:
output["loss"] = (weighted_loss / total_valid).squeeze()
else:
output["loss"] = torch.tensor(
float("nan"), device=loss.device, dtype=loss.dtype
)
self._local_valid_tokens = None
return output
# Register hooks # Register hooks
for model in self.models: for model in self.models:
self.hook_handles.append( self.hook_handles.append(
@@ -298,6 +351,10 @@ class SequenceParallelContextManager:
self.hook_handles.append( self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook) model.register_forward_hook(sequence_parallel_post_hook)
) )
# Always register eval loss correction hook
self.hook_handles.append(
model.register_forward_hook(eval_loss_correction_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor.""" """Gather sharded outputs from all ranks and reconstruct the full tensor."""

View File

@@ -2,11 +2,19 @@
import os import os
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_default_process_count(): def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"): if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc) return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
LOG.warning(
"AXOLOTL_DATASET_PROCESSES and `dataset_processes` are deprecated and will be "
"removed in a future version. Please use `dataset_num_proc` instead."
)
return int(axolotl_dataset_processes) return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
return int(runpod_cpu_count) return int(runpod_cpu_count)

View File

@@ -86,15 +86,15 @@ class HFMistralTokenizer(MistralCommonBackend):
add_generation_prompt: bool = False, add_generation_prompt: bool = False,
**kwargs, **kwargs,
) -> str | list[int]: ) -> str | list[int]:
"""Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg""" """Patched fn to handle setting test mode, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral # pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None) kwargs.pop("real_last_index", None)
kwargs.pop("add_special_tokens", None)
try: try:
if add_generation_prompt: if add_generation_prompt:
self._set_mode(ValidationMode.serving) self._set_mode(ValidationMode.test)
kwargs["continue_final_message"] = True
out = super().apply_chat_template(conversation, **kwargs) out = super().apply_chat_template(conversation, **kwargs)

View File

@@ -609,6 +609,12 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "Whether to use bettertransformers"}, json_schema_extra={"description": "Whether to use bettertransformers"},
) )
sage_attention: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
},
)
eager_attention: bool | None = None eager_attention: bool | None = None
@@ -1120,6 +1126,27 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"):
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to valdiate GPU capabilities with the configured options""" """Wrapper to valdiate GPU capabilities with the configured options"""
@@ -1176,6 +1203,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data return data
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
if (
data.get("sage_attention")
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
):
raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_multigpu_unsloth(cls, data): def check_multigpu_unsloth(cls, data):
@@ -1229,6 +1271,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
): ):
return data return data
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0: if data.get("lora_dropout") != 0:
return data return data

View File

@@ -120,6 +120,12 @@ class ModelOutputConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "how to push checkpoints to hub"}, json_schema_extra={"description": "how to push checkpoints to hub"},
) )
hub_revision: str | None = Field(
default=None,
json_schema_extra={
"description": "branch/revision to push to on hub (default: main)"
},
)
save_safetensors: bool | None = Field( save_safetensors: bool | None = Field(
default=True, default=True,
json_schema_extra={ json_schema_extra={

View File

@@ -166,9 +166,10 @@ class AttentionValidationMixin:
fields = ( fields = (
"xformers_attention", "xformers_attention",
"sdp_attention", "sdp_attention",
"s2_attention", # "s2_attention", # requires both FA and this to be enabled
"flash_attention", "flash_attention",
"flex_attention", "flex_attention",
"sage_attention",
) )
non_empty_count = sum(1 for field in fields if data.get(field)) non_empty_count = sum(1 for field in fields if data.get(field))
@@ -185,9 +186,10 @@ class AttentionValidationMixin:
and not data.get("sdp_attention") and not data.get("sdp_attention")
and not data.get("flex_attention") and not data.get("flex_attention")
and not data.get("xformers_attention") and not data.get("xformers_attention")
and not data.get("sage_attention")
): ):
LOG.warning( LOG.warning(
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination."
) )
return data return data
@@ -688,6 +690,21 @@ class LoRAValidationMixin:
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("trust_remote_code"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with trust_remote_code. Please disable trust_remote_code "
"or explicitly set lora_*_kernel to false."
)
return data
class RLValidationMixin: class RLValidationMixin:
"""Validation methods related to RL training configuration.""" """Validation methods related to RL training configuration."""

View File

@@ -79,7 +79,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800, "ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25, "ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False, "ddp_broadcast_buffers": False,
"dataset_processes": 4, "dataset_num_proc": 4,
} }
) )

View File

@@ -30,7 +30,7 @@ class TestStreamingDatasets:
"sample_packing": sample_packing, "sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing, "pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000, "streaming_multipack_buffer_size": 10000,
"dataset_processes": 1, "dataset_num_proc": 1,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },

View File

@@ -118,20 +118,6 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
assert not manager.enabled assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
assert any(
"Telemetry is now enabled by default" in str(call)
for call in mock_warning.call_args_list
)
def test_is_whitelisted(telemetry_manager_class, mock_whitelist): def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality""" """Test org whitelist functionality"""
with ( with (

View File

@@ -90,3 +90,62 @@ class TestLoRAConfigValidation:
} }
) )
validate_config(invalid_config) validate_config(invalid_config)
@pytest.mark.parametrize(
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
"""Test that lora kernels are incompatible with trust_remote_code"""
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
invalid_config = DictDefault(
{
"adapter": "lora",
kernel_field: True,
"trust_remote_code": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_lora_kernels_trust_remote_code_false(self):
"""Test that lora kernels work when trust_remote_code is false"""
# Test with trust_remote_code=False, lora kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"trust_remote_code": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_mlp_kernel"] is True
assert result["lora_qkv_kernel"] is True
assert result["lora_o_kernel"] is True
# Test with trust_remote_code=None (unset), kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_qkv_kernel": True,
"trust_remote_code": None,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None