Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
87e0fd6b52 feat: add glm 4.7 flash 2026-02-10 18:57:20 +07:00
NanoCode012
2d44432e6c chore: update trinity docs 2026-02-04 18:10:33 +07:00
NanoCode012
57377814e9 feat: update cce for afmoe 2026-02-04 18:00:23 +07:00
42 changed files with 298 additions and 925 deletions

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- | | --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset | | `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub | | `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_num_proc` | `4` | Number of preprocessing processes | | `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory | | `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets | | `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging | | `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,6 +39,7 @@
# type: # linear | dynamic # type: # linear | dynamic
# factor: # float # factor: # float
# # Whether you are training a 4-bit GPTQ quantized model # # Whether you are training a 4-bit GPTQ quantized model
# gptq: true # gptq: true
# gptq_groupsize: 128 # group size # gptq_groupsize: 128 # group size
@@ -106,7 +107,7 @@
# push_dataset_to_hub: # repo path # push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set. # # if not set.
# dataset_num_proc: # defaults to os.cpu_count() if not set # dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub # # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model # hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub # # how to push checkpoints to hub
@@ -348,6 +349,8 @@
# # Allow overwrite yml config using from cli # # Allow overwrite yml config using from cli
# strict: # strict:
base_model: ${BASE_MODEL} base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS} base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG} base_model_config: ${BASE_MODEL_CONFIG}
@@ -406,7 +409,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE} default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH} dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB} push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_num_proc: ${DATASET_NUM_PROC} dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY} dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID} hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY} hub_strategy: ${HUB_STRATEGY}

View File

@@ -251,6 +251,7 @@ website:
- docs/models/olmo3.qmd - docs/models/olmo3.qmd
- docs/models/trinity.qmd - docs/models/trinity.qmd
- docs/models/arcee.qmd - docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3" - section: "Ministral3"
contents: contents:
- docs/models/ministral3.qmd - docs/models/ministral3.qmd
@@ -265,7 +266,6 @@ website:
- docs/models/mistral-small.qmd - docs/models/mistral-small.qmd
- docs/models/voxtral.qmd - docs/models/voxtral.qmd
- docs/models/devstral.qmd - docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd - docs/models/llama-4.qmd
- docs/models/llama-2.qmd - docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd - docs/models/qwen3-next.qmd
@@ -320,7 +320,6 @@ website:
- docs/multipack.qmd - docs/multipack.qmd
- docs/mixed_precision.qmd - docs/mixed_precision.qmd
- docs/optimizers.qmd - docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features" - section: "Advanced Features"
contents: contents:

View File

@@ -1,140 +0,0 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -89,10 +89,6 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT. Currently, LoRA kernels are not supported for RLHF training, only SFT.
::: :::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements ## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels) - One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,7 +19,6 @@ format:
- [Gemma-3n](#sec-gemma-3n) - [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl) - [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl) - [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2) - [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl) - [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl) - [Intern-VL](#sec-intern-vl)
@@ -184,18 +183,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl chat_template: qwen2_vl # same as qwen2-vl
``` ```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2} ### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip} ::: {.callout-tip}

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d\""
] ]
}, },
{ {

View File

@@ -0,0 +1,40 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
3. Run the finetuning example:
```bash
axolotl train examples/glm4.7-flash/glm4.7-flash-qlora.yaml
```
This config uses about X GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Z.ai team recommends `top_p: 0.95`, `temperature: 1.0`, and `max_new_tokens: 131072`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,63 @@
base_model: zai-org/GLM-4.7-Flash
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project: glm-4.7-flash
wandb_entity:
wandb_watch:
wandb_name: qlora
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,44 +0,0 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,53 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,50 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build). 1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example: 2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash ```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
``` ```
This config uses about 24.9 GiB VRAM. This config uses about 24.9 GiB VRAM (w/o CCE).
Let us know how it goes. Happy finetuning! 🚀 Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources ## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto) - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,13 +1,11 @@
base_model: arcee-ai/Trinity-Nano-Preview base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0 revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF # Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name # hub_model_id: username/custom_model_name
# CCE - N/A as of now plugins:
# plugins: - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"'
) )

View File

@@ -258,11 +258,6 @@ class TrainerBuilderBase(abc.ABC):
bf16 = bf16 if bf16 is not None else False bf16 = bf16 if bf16 is not None else False
training_args_kwargs["bf16"] = bf16 training_args_kwargs["bf16"] = bf16
if self.cfg.fp8:
training_args_kwargs["fp8"] = True
if self.cfg.fp8_enable_fsdp_float8_all_gather:
training_args_kwargs["enable_fsdp_float8_all_gather:"] = True
def _configure_scheduler(self, training_args_kwargs: dict): def _configure_scheduler(self, training_args_kwargs: dict):
if self.cfg.lr_scheduler in ["one_cycle", "rex"]: if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
training_args_kwargs["lr_scheduler_type"] = "cosine" training_args_kwargs["lr_scheduler_type"] = "cosine"
@@ -414,9 +409,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy: if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict): def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps # save_strategy and save_steps
if self.cfg.save_steps: if self.cfg.save_steps:

View File

@@ -584,9 +584,11 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess() super().create_accelerator_and_postprocess()
def build_fp8_accelerator_args(self) -> dict[str, Any]: def additional_accelerator_args(
args = {} self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
if self.args.fp8: ) -> dict[str, Any]:
ret_kwargs = {}
if fp8:
from accelerate.utils import AORecipeKwargs from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig from torchao.float8 import Float8LinearConfig
@@ -594,22 +596,15 @@ class AxolotlTrainer(
# scaling strategy. See more details here: # scaling strategy. See more details here:
# https://github.com/pytorch/ao/tree/main/torchao/float8. # https://github.com/pytorch/ao/tree/main/torchao/float8.
config = Float8LinearConfig( config = Float8LinearConfig(
enable_fsdp_float8_all_gather=self.args.enable_fsdp_float8_all_gather, enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=self.args.enable_fsdp_float8_all_gather force_recompute_fp8_weight_in_bwd=enable_fsdp_float8_all_gather is True,
is True,
) )
args["mixed_precision"] = "fp8" ret_kwargs["mixed_precision"] = "fp8"
args["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore ret_kwargs["kwargs_handlers"] = [AORecipeKwargs(config=config)] # type: ignore
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
return args return ret_kwargs
def _build_accelerator_args(self, **kwargs) -> dict[str, Any]:
args = super().build_accelerator_args(**kwargs)
fp8_args = self.build_fp8_accelerator_args()
args.update(fp8_args)
return args
def log(self, logs: dict[str, float], start_time: float | None = None) -> None: def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
""" """
@@ -724,13 +719,6 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}") LOG.info(f"Saving model checkpoint to {output_dir}")
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = ( supported_classes = (
(PreTrainedModel,) (PreTrainedModel,)
if not is_peft_available() if not is_peft_available()

View File

@@ -126,6 +126,9 @@ class GRPOStrategy:
if trl.use_liger_loss is not None: if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
if trl.multi_objective_aggregation is not None: if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = ( grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation trl.multi_objective_aggregation
@@ -151,8 +154,6 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = ( trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes cfg.trl.reward_processing_classes
) )
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs return trainer_kwargs
@@ -163,12 +164,7 @@ class GRPOStrategy:
@classmethod @classmethod
def get_blocklist_args_kwargs(cls) -> list[str]: def get_blocklist_args_kwargs(cls) -> list[str]:
return [ return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
@classmethod @classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc: def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -263,13 +263,3 @@ class AxolotlTrainingMixins:
dion_rank_multiple_of: int | None = field( dion_rank_multiple_of: int | None = field(
default=None, default=None,
) )
fp8: bool | None = field(
default=None,
metadata={"help": "Whether to use FP8 precision for training"},
)
enable_fsdp_float8_all_gather: bool | None = field(
default=None,
metadata={"help": "Whether to use FSDP with FP8 precision for all_gather"},
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"
``` ```
## Usage ## Usage
@@ -31,6 +31,7 @@ plugins:
## Supported Models ## Supported Models
- afmoe
- apertus - apertus
- arcee - arcee
- cohere - cohere
@@ -54,8 +55,8 @@ plugins:
- gpt_oss - gpt_oss
- granite - granite
- granitemoe - granitemoe
- granitemoehybrid
- granitemoeshared - granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense - hunyuan_v1_dense
- hunyuan_v1_moe - hunyuan_v1_moe
- internvl - internvl
@@ -80,17 +81,16 @@ plugins:
- phi3 - phi3
- phi4_multimodal - phi4_multimodal
- qwen2 - qwen2
- qwen2_moe
- qwen2_vl - qwen2_vl
- qwen2_moe
- qwen2_5_vl - qwen2_5_vl
- qwen3 - qwen3
- qwen3_moe - qwen3_moe
- qwen3_next
- qwen3_vl - qwen3_vl
- qwen3_vl_moe - qwen3_vl_moe
- seed_oss - qwen3_next
- smollm3 - smollm3
- step3p5 - seed_oss
- voxtral - voxtral
## Citation ## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@e39ca1d"`'
) )
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like( def patch_llama_like(
self, self,
model_type_to_patch: str, model_type: str,
) -> None: ) -> None:
""" """
Generic patch for model architectures with causal lm similar to llama Generic patch for model architectures with causal lm similar to llama
@@ -112,10 +112,7 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic( def patch_generic(
maybe_model, maybe_model, patch_options, model_type: str, remote_model_id: str | None
patch_options,
remote_model_id: str | None,
model_type: str,
): ):
import cut_cross_entropy.transformers.llama import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward from cut_cross_entropy.transformers.llama import cce_forward
@@ -139,13 +136,11 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e
if model_type_to_patch not in PATCH_FNS: if model_type not in PATCH_FNS:
LOG.warning_once( LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type_to_patch "Setting up generic cce patch for model type: %s", model_type
) )
LOG.warning_once( LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected." f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
) )
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -338,12 +338,7 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility. # we need to convert them back to fp16/bf16 for flash-attn compatibility.
( (
( (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
and not self.is_qlora_and_fsdp_enabled and not self.is_qlora_and_fsdp_enabled
) )
or ( or (
@@ -617,10 +612,6 @@ class ModelLoader:
elif self.cfg.sdp_attention: elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa" self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa" self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention: elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager" self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager" self.model_config._attn_implementation = "eager"

View File

@@ -96,10 +96,10 @@ class PatchManager:
# self._apply_flex_attention_patches() # self._apply_flex_attention_patches()
self._apply_flash_attention_patches() self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch() self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches() self._apply_fsdp_patches()
self._apply_adapter_patches() self._apply_adapter_patches()
self._apply_model_specific_patches() self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches() self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches() self._apply_gradient_checkpointing_patches()
self._patch_attention() self._patch_attention()
@@ -201,13 +201,6 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs) patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self): def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures.""" """Apply patches specific to model architectures."""
if ( if (
@@ -234,6 +227,17 @@ class PatchManager:
patch_kimi_model() patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
def _apply_flash_attention_peft_patches(self): def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT.""" """Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter: if self.cfg.adapter:

View File

@@ -1,211 +0,0 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

View File

@@ -169,8 +169,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
raise ValueError( raise ValueError(
f"Axolotl could not import attention class for model_type: {model_type}. " f"Could not import attention class for model_type: {model_type}. "
"Please raise an Issue and turn off lora kernels to continue training. "
f"Error: {str(e)}" f"Error: {str(e)}"
) from e ) from e

View File

@@ -0,0 +1,83 @@
"""
allow adding additional kwargs to Accelerator init
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object
self.accelerator = Accelerator(**args)
"""
PATCHED_TRAINER_CODE = """
if hasattr(self, "additional_accelerator_args"):
additional_args = self.additional_accelerator_args(fp8=True, enable_fsdp_float8_all_gather={enable_fsdp_float8_all_gather}, **args)
if additional_args:
args.update(additional_args)
# create accelerator object
self.accelerator = Accelerator(**args)
"""
def get_create_accelerate_code() -> str:
training_loop = inspect.getsource(Trainer.create_accelerator_and_postprocess)
return training_loop
def check_create_accelerate_code_is_patchable() -> bool:
create_code = get_create_accelerate_code()
create_code, _ = detab_code(create_code)
return ORIGINAL_TRAINER_CODE in create_code
def patch_create_accelerate_code_for_fp8(enable_fsdp_float8_all_gather: bool):
"""
Monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs.
"""
try:
create_code = get_create_accelerate_code()
except OSError:
return
Trainer._original_create_accelerator_and_postprocess = create_code
create_code, _ = detab_code(create_code)
if ORIGINAL_TRAINER_CODE not in create_code:
return
patched_trainer_code = PATCHED_TRAINER_CODE.format(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
)
create_code = create_code.replace(ORIGINAL_TRAINER_CODE, patched_trainer_code)
create_code = create_code.replace(
"def create_accelerator_and_postprocess(",
"def fixed_create_accelerator_and_postprocess(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in create_code:
items_to_import.append(item)
exec(
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(create_code, globals())
LOG.info("patching create_accelerator_and_postprocess to allow for overrides")
Trainer.create_accelerator_and_postprocess = (
fixed_create_accelerator_and_postprocess
)

View File

@@ -485,58 +485,6 @@ class InternVLProcessingStrategy(ProcessingStrategy):
return labels return labels
class Glm4vProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for GLM4V and GLM4V-MoE vision models."""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.tokenizer = getattr(processor, "tokenizer", processor)
self.image_token = "<|image|>" # nosec
self.begin_image_token = "<|begin_of_image|>" # nosec
self.end_image_token = "<|end_of_image|>" # nosec
self.video_token = "<|video|>" # nosec
self.begin_video_token = "<|begin_of_video|>" # nosec
self.end_video_token = "<|end_of_video|>" # nosec
self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
self.begin_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_image_token
)
self.end_image_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_image_token
)
self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
self.begin_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.begin_video_token
)
self.end_video_token_id = self.tokenizer.convert_tokens_to_ids(
self.end_video_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
labels[labels == self.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == self.begin_image_token_id] = -100
labels[labels == self.end_image_token_id] = -100
labels[labels == self.video_token_id] = -100
labels[labels == self.begin_video_token_id] = -100
labels[labels == self.end_video_token_id] = -100
return labels
def get_processing_strategy( def get_processing_strategy(
processor: ProcessorMixin, processor: ProcessorMixin,
chat_template, chat_template,
@@ -553,10 +501,10 @@ def get_processing_strategy(
"image_resize_algorithm": image_resize_algorithm, "image_resize_algorithm": image_resize_algorithm,
} }
if chat_template_type in [None, "tokenizer_default"]: if chat_template_type in [None, "tokenizer_default"] and hasattr(
tokenizer = getattr(processor, "tokenizer", processor) processor.tokenizer, "chat_template"
if hasattr(tokenizer, "chat_template"): ):
processing_kwargs["chat_template"] = tokenizer.chat_template processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl": if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy( return Qwen2VLProcessingStrategy(
@@ -585,15 +533,6 @@ def get_processing_strategy(
return Mistral3ProcessingStrategy( return Mistral3ProcessingStrategy(
**processing_kwargs, **processing_kwargs,
) )
try:
from transformers.models.glm46v.processing_glm46v import Glm46VProcessor
if isinstance(processor, Glm46VProcessor):
return Glm4vProcessingStrategy(
**processing_kwargs,
)
except ImportError:
pass
if isinstance(processor, InternVLProcessor): if isinstance(processor, InternVLProcessor):
return InternVLProcessingStrategy( return InternVLProcessingStrategy(

View File

@@ -153,27 +153,13 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict: def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss, learning_rate, grad_norm, and token metrics from log history.""" """Extract last loss, learning_rate, and grad_norm from log history."""
if not state.log_history: if not state.log_history:
return { return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
"loss": 0,
"ppl": 0,
"learning_rate": 0,
"grad_norm": 0,
"tokens/total": 0,
"tokens/trainable": 0,
"tokens/train_per_sec_per_gpu": 0,
}
last_log = state.log_history[-1] last_log = state.log_history[-1]
return { return {
"loss": last_log.get("loss", 0), "loss": last_log.get("loss", 0),
"ppl": last_log.get("ppl", 0),
"learning_rate": last_log.get("learning_rate", 0), "learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0), "grad_norm": last_log.get("grad_norm", 0),
"tokens/total": last_log.get("tokens/total", 0),
"tokens/trainable": last_log.get("tokens/trainable", 0),
"tokens/train_per_sec_per_gpu": last_log.get(
"tokens/train_per_sec_per_gpu", 0
),
} }

View File

@@ -155,10 +155,6 @@ def send_errors(func: Callable) -> Callable:
}, },
) )
LOG.error(
f"Error captured in telemetry. Run ID: {telemetry_manager.run_id}"
)
raise raise
return wrapper return wrapper

View File

@@ -5,6 +5,7 @@ import importlib
import logging import logging
import os import os
import platform import platform
import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -19,6 +20,21 @@ LOG = logging.getLogger(__name__)
POSTHOG_HOST = "https://app.posthog.com" POSTHOG_HOST = "https://app.posthog.com"
POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y" POSTHOG_WRITE_KEY = "phc_1kUR0o04oJKKTTeSsIz2Mfm5mpiVsQEf2WOlzljMD7y"
OPT_OUT_WARNING_SLEEP_SECONDS = 10
OPT_OUT_WARNING = (
"\nTelemetry is now enabled by default to help improve Axolotl. "
"If you'd like to disable it, set AXOLOTL_DO_NOT_TRACK=1 in your environment.\n\n"
"Telemetry data helps us understand:\n"
"- Which features are most used\n"
"- What hardware configurations to prioritize\n"
"- Where users encounter errors\n\n"
"Personally identifiable information (PII) is not collected.\n\n"
"To remove this warning, explicitly set AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) "
"or AXOLOTL_DO_NOT_TRACK=1 (disable telemetry).\n\n"
"For details, see: https://docs.axolotl.ai/docs/telemetry.html\n\n"
f"Sleeping for {OPT_OUT_WARNING_SLEEP_SECONDS}s..."
)
WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml") WHITELIST_PATH = str(Path(__file__).parent / "whitelist.yaml")
# NOTE: Need to keep these up to date with any config schema changes # NOTE: Need to keep these up to date with any config schema changes
@@ -30,8 +46,8 @@ FIELDS_TO_REDACT = {
"resume_from_checkpoint", "resume_from_checkpoint",
"hub_model_id", "hub_model_id",
} }
PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_", "trackio_", "swanlab_"} PREFIXES_TO_REDACT = {"wandb_", "comet_", "mlflow_", "gradio_"}
PATH_INDICATORS = {"path", "dir", "data_files"} PATH_INDICATORS = {"path", "dir"}
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
RELEVANT_PACKAGES = { RELEVANT_PACKAGES = {
@@ -167,6 +183,11 @@ class TelemetryManager:
"false", "false",
"true", "true",
): ):
# Print opt-out info message for main process only
if is_main_process():
LOG.warning(OPT_OUT_WARNING)
time.sleep(OPT_OUT_WARNING_SLEEP_SECONDS)
return True return True
# Only rank 0 will send telemetry # Only rank 0 will send telemetry

View File

@@ -31,10 +31,3 @@ organizations:
- "mistral-community" - "mistral-community"
- "llava-hf" - "llava-hf"
- "ByteDance-Seed" - "ByteDance-Seed"
- "ACE-Step"
- "openbmb"
- "MiniMaxAI"
- "stepfun-ai"
- "internlm"
- "katanemo"
- "XiaomiMiMo"

View File

@@ -78,19 +78,12 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
tokens = getattr(state, "tokens", None) tokens = getattr(state, "tokens", None)
if not (tokens and "trainable_tokens" in tokens): if tokens and "trainable_tokens" in tokens:
return step_time = time.perf_counter() - self.start_time
step_time = time.perf_counter() - self.start_time num_tokens_per_device = tokens["trainable_tokens"].clone()
if step_time <= 0: # non data parallel groups have duplicated tokens, so we avoid double-counting
return num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
if torch.distributed.is_initialized():
dp_size = max(
1, torch.distributed.get_world_size() // self.non_data_parallel_size
)
num_tokens = num_tokens / dp_size
state.last_tokens_per_second = num_tokens / step_time
def on_log( def on_log(
self, self,

View File

@@ -218,9 +218,6 @@ class SequenceParallelContextManager:
self.original_seq_len = 0 self.original_seq_len = 0
self.pad_len = 0 self.pad_len = 0
# Track local valid token count for eval loss correction across CP ranks
self._local_valid_tokens: torch.Tensor | None = None
# Create a partially applied version of the apply_sequence_parallelism function # Create a partially applied version of the apply_sequence_parallelism function
self.apply_sequence_parallelism = functools.partial( self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism, apply_sequence_parallelism,
@@ -273,18 +270,6 @@ class SequenceParallelContextManager:
self.apply_sequence_parallelism(updated_kwargs) self.apply_sequence_parallelism(updated_kwargs)
) )
# Track local valid tokens for eval loss correction
if "labels" in updated_kwargs and not self.models[0].training:
self._local_valid_tokens = (
(updated_kwargs["labels"] != -100).sum().float()
)
# Strip num_items_in_batch during eval so the model uses
# reduction='mean', allowing the post-hook weighted all-reduce
# formula (loss * local_valid) to correctly recover the loss sum
updated_kwargs.pop("num_items_in_batch", None)
else:
self._local_valid_tokens = None
return remaining_args, updated_kwargs return remaining_args, updated_kwargs
# Forward post-hook to gather outputs # Forward post-hook to gather outputs
@@ -302,44 +287,6 @@ class SequenceParallelContextManager:
return output return output
# Post-hook to correct eval loss via weighted all-reduce across CP ranks
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
if self._local_valid_tokens is None:
return output
if not hasattr(output, "loss") or output.loss is None:
return output
local_valid = self._local_valid_tokens.to(output.loss.device)
loss = output.loss.detach().clone()
# Handle rank with zero valid tokens (loss is NaN)
if local_valid.item() == 0:
weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)
else:
weighted_loss = loss * local_valid
total_valid = local_valid.clone()
dist.all_reduce(
weighted_loss,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
dist.all_reduce(
total_valid,
op=dist.ReduceOp.SUM,
group=self.process_group,
)
if total_valid.item() > 0:
output["loss"] = (weighted_loss / total_valid).squeeze()
else:
output["loss"] = torch.tensor(
float("nan"), device=loss.device, dtype=loss.dtype
)
self._local_valid_tokens = None
return output
# Register hooks # Register hooks
for model in self.models: for model in self.models:
self.hook_handles.append( self.hook_handles.append(
@@ -351,10 +298,6 @@ class SequenceParallelContextManager:
self.hook_handles.append( self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook) model.register_forward_hook(sequence_parallel_post_hook)
) )
# Always register eval loss correction hook
self.hook_handles.append(
model.register_forward_hook(eval_loss_correction_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor.""" """Gather sharded outputs from all ranks and reconstruct the full tensor."""

View File

@@ -2,19 +2,11 @@
import os import os
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_default_process_count(): def get_default_process_count():
if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"): if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
return int(axolotl_dataset_num_proc) return int(axolotl_dataset_num_proc)
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"): if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
LOG.warning(
"AXOLOTL_DATASET_PROCESSES and `dataset_processes` are deprecated and will be "
"removed in a future version. Please use `dataset_num_proc` instead."
)
return int(axolotl_dataset_processes) return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"): if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
return int(runpod_cpu_count) return int(runpod_cpu_count)

View File

@@ -86,15 +86,15 @@ class HFMistralTokenizer(MistralCommonBackend):
add_generation_prompt: bool = False, add_generation_prompt: bool = False,
**kwargs, **kwargs,
) -> str | list[int]: ) -> str | list[int]:
"""Patched fn to handle setting test mode, remove chat_template and add_generation_prompt kwarg""" """Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg"""
# pop unnecessary kwarg for mistral # pop unnecessary kwarg for mistral
kwargs.pop("real_last_index", None) kwargs.pop("real_last_index", None)
kwargs.pop("add_special_tokens", None)
try: try:
if add_generation_prompt: if add_generation_prompt:
self._set_mode(ValidationMode.test) self._set_mode(ValidationMode.serving)
kwargs["continue_final_message"] = True
out = super().apply_chat_template(conversation, **kwargs) out = super().apply_chat_template(conversation, **kwargs)

View File

@@ -609,12 +609,6 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={"description": "Whether to use bettertransformers"}, json_schema_extra={"description": "Whether to use bettertransformers"},
) )
sage_attention: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention"
},
)
eager_attention: bool | None = None eager_attention: bool | None = None
@@ -1126,27 +1120,6 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_sageattn_wo_sample_packing(cls, data):
if (not data.get("sample_packing", False)) and data.get("sage_attention"):
if not data.get("pad_to_sequence_len", False):
LOG.warning(
"We recommend turning on `pad_to_sequence_len` for SageAttention without packing."
"This is because there has been signs that the loss explodes after a few steps."
)
return data
@model_validator(mode="before")
@classmethod
def check_sageattn_fft(cls, data):
if (not data.get("adapter", False)) and data.get("sage_attention"):
LOG.warning(
"We found loss to drop to 0 with SageAttention full finetuning."
"Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to valdiate GPU capabilities with the configured options""" """Wrapper to valdiate GPU capabilities with the configured options"""
@@ -1203,21 +1176,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data return data
@model_validator(mode="before")
@classmethod
def check_compute_capability_w_sageattn(cls, data):
if (
data.get("sage_attention")
and data.get("capabilities")
and data.get("capabilities").get("compute_capability")
not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"]
):
raise ValueError(
"SageAttention supports compute capability between sm_80 and sm_120. "
"Please use a different attention implementation."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_multigpu_unsloth(cls, data): def check_multigpu_unsloth(cls, data):
@@ -1271,10 +1229,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
): ):
return data return data
# Skip if trust_remote_code is enabled, as lora kernels are not compatible
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0: if data.get("lora_dropout") != 0:
return data return data

View File

@@ -120,12 +120,6 @@ class ModelOutputConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "how to push checkpoints to hub"}, json_schema_extra={"description": "how to push checkpoints to hub"},
) )
hub_revision: str | None = Field(
default=None,
json_schema_extra={
"description": "branch/revision to push to on hub (default: main)"
},
)
save_safetensors: bool | None = Field( save_safetensors: bool | None = Field(
default=True, default=True,
json_schema_extra={ json_schema_extra={

View File

@@ -166,10 +166,9 @@ class AttentionValidationMixin:
fields = ( fields = (
"xformers_attention", "xformers_attention",
"sdp_attention", "sdp_attention",
# "s2_attention", # requires both FA and this to be enabled "s2_attention",
"flash_attention", "flash_attention",
"flex_attention", "flex_attention",
"sage_attention",
) )
non_empty_count = sum(1 for field in fields if data.get(field)) non_empty_count = sum(1 for field in fields if data.get(field))
@@ -186,10 +185,9 @@ class AttentionValidationMixin:
and not data.get("sdp_attention") and not data.get("sdp_attention")
and not data.get("flex_attention") and not data.get("flex_attention")
and not data.get("xformers_attention") and not data.get("xformers_attention")
and not data.get("sage_attention")
): ):
LOG.warning( LOG.warning(
"sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination." "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
) )
return data return data
@@ -690,21 +688,6 @@ class LoRAValidationMixin:
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_lora_kernels_trust_remote_code(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("trust_remote_code"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with trust_remote_code. Please disable trust_remote_code "
"or explicitly set lora_*_kernel to false."
)
return data
class RLValidationMixin: class RLValidationMixin:
"""Validation methods related to RL training configuration.""" """Validation methods related to RL training configuration."""

View File

@@ -79,7 +79,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800, "ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25, "ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False, "ddp_broadcast_buffers": False,
"dataset_num_proc": 4, "dataset_processes": 4,
} }
) )

View File

@@ -30,7 +30,7 @@ class TestStreamingDatasets:
"sample_packing": sample_packing, "sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing, "pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000, "streaming_multipack_buffer_size": 10000,
"dataset_num_proc": 1, "dataset_processes": 1,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },

View File

@@ -118,6 +118,20 @@ def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):
assert not manager.enabled assert not manager.enabled
def test_opt_in_info_displayed(telemetry_manager_class):
"""Test that opt-in info is displayed when telemetry is not configured"""
with (
patch.dict(os.environ, {"RANK": "0"}, clear=True),
patch("logging.Logger.warning") as mock_warning,
patch("time.sleep"),
):
telemetry_manager_class()
assert any(
"Telemetry is now enabled by default" in str(call)
for call in mock_warning.call_args_list
)
def test_is_whitelisted(telemetry_manager_class, mock_whitelist): def test_is_whitelisted(telemetry_manager_class, mock_whitelist):
"""Test org whitelist functionality""" """Test org whitelist functionality"""
with ( with (

View File

@@ -90,62 +90,3 @@ class TestLoRAConfigValidation:
} }
) )
validate_config(invalid_config) validate_config(invalid_config)
@pytest.mark.parametrize(
"kernel_field", ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
def test_lora_kernels_trust_remote_code_incompatible(self, kernel_field):
"""Test that lora kernels are incompatible with trust_remote_code"""
with pytest.raises(ValueError, match="not compatible with trust_remote_code"):
invalid_config = DictDefault(
{
"adapter": "lora",
kernel_field: True,
"trust_remote_code": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_lora_kernels_trust_remote_code_false(self):
"""Test that lora kernels work when trust_remote_code is false"""
# Test with trust_remote_code=False, lora kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"trust_remote_code": False,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_mlp_kernel"] is True
assert result["lora_qkv_kernel"] is True
assert result["lora_o_kernel"] is True
# Test with trust_remote_code=None (unset), kernels should be allowed
valid_config = DictDefault(
{
"adapter": "lora",
"lora_qkv_kernel": True,
"trust_remote_code": None,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None