Compare commits
6 commits: v0.13.0...coderabbit

| SHA1 |
|---|
| 93600fa80d |
| b234532d9f |
| 8990ca3205 |
| 006f226270 |
| 0b635e69c5 |
| 0d27e14e45 |

.github/workflows/base.yml (8 changes)

@@ -57,14 +57,14 @@ jobs:
           cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.9.0
+          pytorch: 2.9.1
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-base"
         - cuda: "130"
           cuda_version: 13.0.0
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.9.0
+          pytorch: 2.9.1
           torch_cuda_arch_list: "9.0+PTX"
           dockerfile: "Dockerfile-base"
         # - cuda: "128"

@@ -146,14 +146,14 @@ jobs:
           cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.9.0
+          pytorch: 2.9.1
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-uv-base"
         - cuda: "130"
           cuda_version: 13.0.0
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.9.0
+          pytorch: 2.9.1
           torch_cuda_arch_list: "9.0+PTX"
           dockerfile: "Dockerfile-uv-base"
     steps:

.github/workflows/main.yml (20 changes)

@@ -36,6 +36,16 @@ jobs:
           pytorch: 2.8.0
           axolotl_extras:
           is_latest: true
+        - cuda: 128
+          cuda_version: 12.8.1
+          python_version: "3.11"
+          pytorch: 2.9.0
+          axolotl_extras:
+        - cuda: 128
+          cuda_version: 12.8.1
+          python_version: "3.11"
+          pytorch: 2.9.1
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

@@ -109,6 +119,16 @@ jobs:
           pytorch: 2.8.0
           axolotl_extras:
           is_latest: true
+        - cuda: 128
+          cuda_version: 12.8.1
+          python_version: "3.11"
+          pytorch: 2.9.0
+          axolotl_extras:
+        - cuda: 128
+          cuda_version: 12.8.1
+          python_version: "3.11"
+          pytorch: 2.9.1
+          axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

@@ -29,6 +29,7 @@
 
 ## 🎉 Latest Updates
 
+- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3).
 - 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
 - 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
 - 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).

@@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \
     pip3 install -U --no-cache-dir pydantic==1.10.10 && \
     pip3 cache purge
 
-RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
+RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
         wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
         pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
         rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \

@@ -4,7 +4,7 @@ format:
   html:
     toc: true
     toc-depth: 3
-    number-sections: true
+    # number-sections: true
     code-tools: true
 execute:
   enabled: false

@@ -14,12 +14,18 @@ This guide covers advanced training configurations for multi-GPU setups using Ax
 
 ## Overview {#sec-overview}
 
-Axolotl supports several methods for multi-GPU training:
+When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.
 
-- DeepSpeed (recommended)
-- FSDP (Fully Sharded Data Parallel)
-- Sequence parallelism
-- FSDP + QLoRA
+You generally cannot combine these strategies; they are mutually exclusive.
+
+1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.
+2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).
+3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected).
+
+These features can often be combined with the strategies above:
+
+* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).
+* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).
 
 ## DeepSpeed {#sec-deepspeed}
 
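
To make the three strategies concrete, here is a minimal, hedged sketch of how each would be selected in an Axolotl YAML config; `fsdp_version` is the field this guide discusses below, while the DeepSpeed JSON path is illustrative:

```yaml
# Pick exactly one strategy per run; they are mutually exclusive.

# 1. DeepSpeed: point at a ZeRO config JSON (path is illustrative)
# deepspeed: deepspeed_configs/zero2.json

# 2. FSDP2 (recommended): selected via the top-level fsdp_version field
fsdp_version: 2

# 3. DDP: no field needed; it is the default when neither of the above is set.
```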

@@ -65,12 +71,18 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
 
 ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
 
+FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.
+
 ::: {.callout-note}
 
 FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
 
 :::
 
+### FSDP + QLoRA {#sec-fsdp-qlora}
+
+For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
+
 ### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
 
 To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
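
A hedged before/after sketch of the FSDP1-to-FSDP2 migration described above; only the `fsdp_version` field is confirmed by this diff, so the `fsdp_config` key names are assumptions:

```yaml
# FSDP1 (deprecated) -- assumed legacy shape
# fsdp:
#   - full_shard
#   - auto_wrap
# fsdp_config:
#   fsdp_offload_params: false

# FSDP2 -- `fsdp_version` is the confirmed top-level field
fsdp_version: 2
fsdp_config:
  offload_params: false  # assumption: FSDP2 keys drop the fsdp_ prefix
```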

@@ -145,10 +157,6 @@ single sequence causes OOM errors during model training.
 
 See our [dedicated guide](sequence_parallelism.qmd) for more information.
 
-### FSDP + QLoRA {#sec-fsdp-qlora}
-
-For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
-
 ## Performance Optimization {#sec-performance}
 
 ### Liger Kernel Integration {#sec-liger}

@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\""
    ]
  },
  {

examples/olmo3/README.md (new file, 46 lines)

@@ -0,0 +1,46 @@
+# Finetune Allenai's Olmo 3 with Axolotl
+
+[Olmo 3](https://huggingface.co/collections/allenai/olmo-3) is a family of 7B and 32B open source models trained by The Allen Institute for Artificial Intelligence.
+
+This guide shows how to fine-tune these models with Axolotl with multi-turn conversations and proper masking.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+Here is an example of how to install from pip:
+```bash
+# Ensure you have a compatible version of Pytorch installed
+pip3 install packaging setuptools wheel ninja
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+
+# Install Cut Cross Entropy
+python scripts/cutcrossentropy_install.py | sh
+```
+
+2. Run the finetuning example:
+
+```bash
+axolotl train examples/olmo3/olmo3-7b-qlora.yaml
+```
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### TIPS
+
+- The example config can be re-used for Olmo and Olmo 2.
+- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config.
+- Read more on how to load your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+## Optimization Guides
+
+Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
+
+## Related Resources
+
+- [Olmo 3 Blog](https://allenai.org/blog/olmo3)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

examples/olmo3/olmo3-7b-qlora.yaml (new file, 64 lines)

@@ -0,0 +1,64 @@
+base_model: allenai/Olmo-3-7B-Instruct-SFT
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/lora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
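
Following the TIPS in the README above, a hedged sketch of the full-finetune variant of this config; the removed fields come straight from the YAML, while the batch-size adjustment is an assumption:

```yaml
# Sketch: full-finetune variant of olmo3-7b-qlora.yaml (per the README tip).
base_model: allenai/Olmo-3-7B-Instruct-SFT

# Removed relative to the QLoRA config above:
# adapter: qlora
# load_in_4bit: true

# Assumption: full finetuning uses more memory, so a smaller micro batch may help.
micro_batch_size: 1
```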

@@ -6,21 +6,17 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
 
 ## Getting started
 
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Seed-OSS is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
 
-Here is an example of how to install from main for pip:
-
+Here is an example of how to install from pip:
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
-pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
+# Ensure you have a compatible version of Pytorch installed
+pip3 install packaging setuptools wheel ninja
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 
 # Install Cut Cross Entropy
 python scripts/cutcrossentropy_install.py | sh
 ```
 
 2. Run the finetuning example:
 

@@ -41,9 +37,7 @@ Let us know how it goes. Happy finetuning! 🚀
 
 ## Optimization Guides
 
-- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
-- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
-- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
 
 ## Related Resources
 

@@ -37,9 +37,7 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
 
 ## Optimization Guides
 
-- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
-- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
-- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
 
 ## Related Resources
 

@@ -11,7 +11,7 @@ liger-kernel==0.6.3
 packaging==23.2
 
 huggingface_hub>=0.36.0
-peft>=0.17.1
+peft>=0.18.0
 tokenizers>=0.22.1
 transformers==4.57.1
 accelerate==1.11.0

@@ -42,7 +42,6 @@ numpy>=2.2.6
 # qlora things
 evaluate==0.4.1
 scipy
-scikit-learn==1.4.2
 nvidia-ml-py==12.560.30
 art
 tensorboard

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
 
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"'
 )

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"
 ```
 
 ## Usage

@@ -65,6 +65,9 @@ plugins:
 - mistral3
 - mixtral
 - mllama
+- olmo
+- olmo2
+- olmo3
 - phi
 - phi3
 - phi4_multimodal

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@8a1a0ec"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`'
 )
 
 

@@ -23,6 +23,29 @@ class LigerPlugin(BasePlugin):
         return "axolotl.integrations.liger.LigerArgs"
 
     def pre_model_load(self, cfg):
+        """
+        Apply LIGER runtime patches and integrations according to the provided configuration.
+
+        This hook inspects `cfg` and conditionally applies LIGER kernel patches, replacements, and model-specific integrations (rotary embeddings, normalization, GLU variants, and cross-entropy implementations) for the model type indicated by `cfg.model_config_type`. Behavior is driven entirely by the various `cfg.liger_*` flags; the method logs actions and warnings when support is experimental or unavailable.
+
+        Parameters:
+            cfg: Configuration object containing LIGER-related flags and model identification. Expected attributes include:
+                - model_config_type (str): Target model config type to determine which patches to apply.
+                - base_model (str): Base model identifier used when probing model modules (used for some model types).
+                - trust_remote_code (bool|None): Passed when loading remote model code (used for some model types).
+                - torch_compile (bool): If true, disable torch.compile optimizations for certain LIGER kernels.
+                - liger_cross_entropy (bool)
+                - liger_fused_linear_cross_entropy (bool)
+                - liger_use_token_scaling (bool)
+                - liger_rope (bool)
+                - liger_rms_norm (bool)
+                - liger_layer_norm (bool)
+                - liger_glu_activation (str|bool): Name or flag for GLU/SwiGLU activation selection.
+                (Other LIGER flags referenced by the code may also be consulted.)
+
+        Raises:
+            ValueError: If both `cfg.liger_cross_entropy` and `cfg.liger_fused_linear_cross_entropy` are enabled.
+        """
         if cfg.torch_compile:
             # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
             import liger_kernel.ops.fused_linear_cross_entropy

@@ -168,6 +191,22 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
+        elif cfg.model_config_type == "qwen3_vl":
+            """
+            Apply Liger kernels for Qwen3 Vision-Language models.
+
+            Note: The parameter 'swiglu' is used instead of 'glu_activation' to match
+            the Liger kernel API for vision-language models.
+            """
+            from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
+
+            apply_liger_kernel_to_qwen3_vl(
+                rope=cfg.liger_rope,
+                cross_entropy=cfg.liger_cross_entropy,
+                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+                rms_norm=cfg.liger_rms_norm,
+                swiglu=cfg.liger_glu_activation,  # Note: qwen3_vl uses swiglu parameter name
+            )
         elif cfg.model_config_type == "qwen3_moe":
             from axolotl.integrations.liger.models.qwen3_moe import (
                 apply_liger_kernel_to_qwen3_moe,
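
A hedged sketch of a config that would exercise the new `qwen3_vl` branch; the `liger_*` flags are the ones read by the plugin above, while the plugin path and model name are assumptions:

```yaml
# Sketch: Liger kernels for a Qwen3-VL model (base_model is an assumed example).
base_model: Qwen/Qwen3-VL-8B-Instruct
plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true              # passed through as `swiglu=` for qwen3_vl
liger_fused_linear_cross_entropy: true
# Leave liger_cross_entropy unset: enabling both raises ValueError per the docstring above.
```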

@@ -206,4 +245,4 @@ class LigerPlugin(BasePlugin):
         else:
             LOG.warning(
                 f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
             )

@@ -102,6 +102,8 @@ def load_lora(
         lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
     if cfg.peft_trainable_token_indices:
         lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
+    if cfg.peft_ensure_weight_tying is not None:
+        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
 
     # Determine the correct PEFT task type
     model_cls = type(model).__name__

@@ -49,6 +49,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "seed_oss",
     "lfm2",
     "lfm2_moe",
+    "olmo",
+    "olmo2",
+    "olmo3",
 ]
 
 
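
With "olmo", "olmo2", and "olmo3" added to `SUPPORTED_MULTIPACK_MODEL_TYPES`, sample packing can now be enabled for these models; a hedged sketch reusing fields from the Olmo 3 example config earlier in this compare:

```yaml
# Sketch: sample packing with a newly supported Olmo model type.
base_model: allenai/Olmo-3-7B-Instruct-SFT
sequence_len: 2048
sample_packing: true    # valid now that "olmo3" is a supported multipack type
flash_attention: true
```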

@@ -100,6 +100,15 @@ class LoraConfig(BaseModel):
             )
         },
     )
+    peft_ensure_weight_tying: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "Whether to tie adapter weights for tied model weights. "
+                "See https://github.com/huggingface/peft/issues/2864"
+            )
+        },
+    )
     qlora_sharded_model_loading: bool | None = Field(
         default=False,
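
The new `peft_ensure_weight_tying` option wires through to PEFT's `ensure_weight_tying` kwarg (see the `load_lora` hunk above). A usage sketch in an Axolotl config, with the surrounding LoRA fields echoing the example config earlier in this compare:

```yaml
# Sketch: tie adapter weights for a model with tied input/output embeddings.
adapter: lora
lora_r: 32
lora_alpha: 16
lora_target_linear: true
peft_ensure_weight_tying: true  # forwarded to PEFT as ensure_weight_tying
```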