Compare commits

...

51 Commits

Author SHA1 Message Date
Sunny Liu
0179021780 fix attribute error 2025-04-21 22:29:24 -04:00
Sunny Liu
c4910da015 update more tests + better hqq validation 2025-04-21 22:17:08 -04:00
Sunny Liu
db7e92f6a6 check if self.cfg.quantization exists when directly setting load_in_4bit 2025-04-21 21:42:23 -04:00
Sunny Liu
136b37e4d4 restore support for legacy cfg.load_in_xbit 2025-04-21 21:32:01 -04:00
Sunny Liu
92644513c4 update relora 2025-04-21 21:22:44 -04:00
Sunny Liu
266ef3f479 skip set_quant_config if quantization not given 2025-04-21 17:17:41 -04:00
Sunny Liu
fcef8c95fe skip set_quant_config if quantization not given 2025-04-21 17:17:20 -04:00
Sunny Liu
136407c556 update multigpu/test_qwen2 2025-04-21 17:04:17 -04:00
Sunny Liu
3251b3235f update test_mixtral 2025-04-21 17:01:07 -04:00
Sunny Liu
1aa9f7d952 update multigpu/test_eval, multigpu/test_llama 2025-04-21 16:49:08 -04:00
Sunny Liu
a20e753321 update test_falcon_samplepack 2025-04-21 16:29:49 -04:00
Sunny Liu
cb121ab91b update test_mixtral [skip e2e] 2025-04-21 16:27:26 -04:00
Sunny Liu
b59640a4c7 amend model loading for hqq + fix hqq version 2025-04-21 15:53:43 -04:00
Sunny Liu
f0a189131b amend model loading for hqq + fix hqq version 2025-04-21 15:53:29 -04:00
Sunny Liu
c8fb5baad6 amend unittests pt2 2025-04-21 13:28:52 -04:00
Sunny Liu
9be971d47c update test_models.py to conform to new quantization config 2025-04-21 11:34:37 -04:00
Sunny Liu
ffd4ef1ece nit 2025-04-21 11:28:59 -04:00
Sunny Liu
320aff1867 update config doc 2025-04-21 10:59:04 -04:00
Sunny Liu
ac24eba2ac include HQQLinear in find target_linear 2025-04-21 10:36:39 -04:00
Sunny Liu
8a5ad8aee3 typo 2025-04-21 10:36:39 -04:00
Sunny Liu
843b50fdaa rigorous qlora validation 2025-04-21 10:36:39 -04:00
Sunny Liu
098ffcc5a2 removed redundant hqq config validation 2025-04-21 10:36:39 -04:00
Sunny Liu
ba8e29c841 quantization config refactoring - better integration 2025-04-21 10:36:39 -04:00
Sunny Liu
143b2e082c nit [skip e2e] 2025-04-21 10:36:39 -04:00
Sunny Liu
aba484de97 WIP quant config refactor 2025-04-21 10:36:39 -04:00
Sunny Liu
f6f5f89c6d fix more typo 2025-04-21 10:36:39 -04:00
Sunny Liu
8926fe9981 lax config requirement - qlora + hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
987c5217a0 fix typos 2025-04-21 10:36:39 -04:00
Sunny Liu
feaef03cb9 didn't realise model_config.quantization_config is just a regular dict 2025-04-21 10:36:39 -04:00
Sunny Liu
ba5d917845 add e2e test for hqq training 2025-04-21 10:36:39 -04:00
Sunny Liu
0e9b060b4d add doc + requirement for hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
0c40d12a18 more comprehensive hqq config options 2025-04-21 10:36:39 -04:00
Sunny Liu
f55b3c805b hqq_nbits triggers prepare_model_for_kbit_training 2025-04-21 10:36:39 -04:00
Sunny Liu
a64601f957 fix wrong variable name 2025-04-21 10:36:39 -04:00
Sunny Liu
eb7bc70b99 fix dumb mistake 2025-04-21 10:36:39 -04:00
Sunny Liu
db6c76b147 forgot to return data in check 2025-04-21 10:36:39 -04:00
Sunny Liu
99730ce40a hqq integration 2025-04-21 10:36:39 -04:00
Wing Lian
7651550850 make sure to download fixtures for kd test (#2541)
* make sure to download fixtures for kd test

* use same alpaca dataset
2025-04-21 10:31:50 -04:00
Wing Lian
341e95aac9 prevent rate limiting to hf when using dispatch batches (#2536) [skip ci] 2025-04-21 10:31:35 -04:00
Catgat
b882dfb63f Fixed Rex Scheduler Warm Up (#2535) [skip ci]
* Fixed Rex Scheduler Warm Up

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-21 10:30:55 -04:00
Wing Lian
b640db1dbc don't run multigpu tests twice, run SP in separate test (#2542)
* don't run multigpu tests twice, run SP in separate test

* fix multiline
2025-04-21 10:24:13 -04:00
Chiwan Park
4ce469d32e fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
2025-04-18 09:57:40 -07:00
Wing Lian
60a8f0958d zero val fix for beta (#2538) 2025-04-17 17:27:19 -07:00
NanoCode012
9da730d6a4 fix(doc): cut cross entropy installation instructions broken in qmd (#2532) 2025-04-16 15:02:51 -07:00
NanoCode012
32637fad00 fix: preprocess yielding whole dataset to each worker (#2503) [skip ci] 2025-04-16 15:02:35 -07:00
Dan Saunders
f776f889a1 adding codecov reporting (#2372) [skip ci]
* adding codecov reporting

* update codecov-action to v5

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-16 15:02:17 -07:00
Wing Lian
69eda209a6 re-enable DS zero3 ci with updated transformers (#2533) 2025-04-16 14:48:40 -07:00
Dan Saunders
b8c633aa97 batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* removing pad_to_sequence_len=False for now

* fix

* updating docs to include batch SP

* review comments

* fixes for batch API funcs, simplify

* fixes

* fix

* updates

* add batch_zigzag smoke test
2025-04-16 13:50:48 -04:00
NanoCode012
682a9cf79b Fix: add delinearization and make qlora work with fsdp2 (#2515)
* fixes for delinearization, and make qlora work with fsdp2

* Add back mistakenly removed lm_eval

* typo [skip ci]

* patch evals for torch.compile + fsdp2

* also check torch_compile w fsdp2

* lots of fixes for flex attn with llama4

* fix patch check and patch llama4 too

* attempt to make the patches stick

* use transformers 4.51.2

* update configs and README for llama4

* remove torch.compile for CI test

* cleanup any existing singletons

* set singleton cache to None instead of deleting

* use importlib reload with monkeypatch

* don't worry about transformers version, mark inputs with grads, fix regex

* make sure embeds aren't on cpu

* logging and mem improvements

* vllm version and add to docker, make sure to save processor on conversion

* fix ambiguous tensor bool check

* fix vllm to not use v1, upgrade hf transformers

* fix tests

* make flex_attn_compile_kwargs configurable, since this depends on model params

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-04-15 23:31:39 -07:00
NanoCode012
271b24cccc feat: update cce to latest (#2521) 2025-04-15 22:17:10 -07:00
Wing Lian
198d775d6d make sure the all of the model is on the same device, so this test will pass on multigpu (#2524) [skip ci] 2025-04-15 22:15:42 -07:00
73 changed files with 1757 additions and 270 deletions

14
.coveragerc Normal file
View File

@@ -0,0 +1,14 @@
[run]
source = axolotl
omit =
*/tests/*
setup.py
[report]
exclude_lines =
pragma: no cover
def __repr__
raise NotImplementedError
if __name__ == .__main__.:
pass
raise ImportError

View File

@@ -29,7 +29,7 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:

View File

@@ -102,9 +102,16 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |

View File

@@ -9,6 +9,7 @@
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>

View File

@@ -3,10 +3,59 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl \
--cov-report=xml:coverage.xml
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/patched/lora_kernels \
--cov=axolotl \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest -v --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \
--cov-append
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/solo/ \
--cov=axolotl \
--cov-append
# Run integration tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/integrations/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 /workspace/axolotl/tests/cli \
--cov=axolotl \
--cov-append
# Run remaining e2e tests with coverage append and final report
pytest -v --durations=10 \
--ignore=tests/e2e/solo/ \
--ignore=tests/e2e/patched/ \
--ignore=tests/e2e/multigpu/ \
--ignore=tests/e2e/integrations/ \
--ignore=tests/cli \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:coverage.xml
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -1,6 +1,27 @@
#!/bin/bash
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \
--cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

51
codecov.yml Normal file
View File

@@ -0,0 +1,51 @@
codecov:
require_ci_to_pass: yes
coverage:
precision: 2
round: down
range: "70...100"
status:
project:
default:
# basic
target: auto
threshold: 0%
base: auto
# advanced
branches: null
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
flags: null
paths: null
patch:
default:
# basic
target: auto
threshold: 0%
base: auto
# advanced
branches: null
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
flags: null
paths: null
parsers:
gcov:
branch_detection:
conditional: yes
loop: yes
method: no
macro: no
comment:
layout: "reach,diff,flags,files,footer"
behavior: default
require_changes: no
require_base: no
require_head: yes

View File

@@ -55,20 +55,46 @@ overrides_of_model_config:
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config paramters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq:
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
load_in_8bit:
# Use bitsandbytes 4 bit
load_in_4bit:
@@ -693,6 +719,9 @@ sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
# in the sample packing case, and "batch_ring" in the non-sample packing case.
ring_attn_func:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

View File

@@ -1,16 +1,28 @@
# Llama 4 by Meta AI
## Flash Attention vs Flex Attention
While Flash Attention to support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
## Available Examples
### Llama 4 Scout 17Bx16Experts (109B)
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
Flex Attention
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
[//]: # (Flash Attention &#40;Do not use&#41;)
[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1]&#40;./scout-vision-qlora-fsdp.yaml&#41;)
[//]: # (- [Text Single GPU &#40;H100&#41; QLoRA]&#40;./scout-qlora-single-h100.yaml&#41;)
[//]: # (- [Text Multi GPU QLoRA w/ FSDP1]&#40;./scout-qlora-fsdp1.yaml&#41;)
Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @ 280tps/gpu, [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
### Llama 4 Maverick 17Bx128Experts (400B)
- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
Coming Soon

View File

@@ -0,0 +1,86 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head
# - embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -0,0 +1,85 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$ # optionally train the moe experts
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head # needed if modifying vocabulary
# - embed_tokens
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
torch_compile: true
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -0,0 +1,89 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
sequence_len: 4096
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
- vision_adapter.mlp.fc1
- vision_adapter.mlp.fc2
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -1,6 +1,6 @@
pre-commit
black
mypy
pre-commit
types-requests
quartodoc
jupyter

View File

@@ -1,5 +1,7 @@
codecov
pytest
pytest-xdist
pytest-cov
pytest-retry
pytest-sugar
pytest-xdist
tbparse

View File

@@ -6,13 +6,13 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.6
liger-kernel==0.5.8
# END section
packaging==23.2
peft==0.15.1
transformers==4.51.1
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
@@ -22,6 +22,7 @@ hf_xet==1.0.0
optimum==1.16.2
hf_transfer
hqq==0.2.5
sentencepiece
gradio==5.23.3

View File

@@ -25,5 +25,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
)

View File

@@ -67,7 +67,7 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -0,0 +1,156 @@
"""
CLI tool to delinearize quantized/Linearized Llama-4 models.
"""
import os
from pathlib import Path
from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
keys = list(model_state_dict.keys())
for key in keys:
if ".feed_forward.experts." not in key:
yield key, model_state_dict[key]
if ".feed_forward.experts.gate_projs" in key:
# gate gets fused with up so skip the yield on this and we'll fuse it when asking for the up
continue
if ".feed_forward.experts.up_projs" in key:
if ".feed_forward.experts.up_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".up_projs.0.")[0]
key = f"{prefix}.gate_up_proj"
# grab all the up_projs and gate_projs across all experts
gate_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.gate_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
up_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.up_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
del gate_stacked, up_stacked
yield key, gate_up_proj
else:
del model_state_dict[key]
continue
if ".feed_forward.experts.down_projs" in key:
if ".feed_forward.experts.down_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".down_projs.0.")[0]
key = f"{prefix}.down_proj"
# grab all the down_projs across all experts
down_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.down_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
yield key, down_stacked
else:
del model_state_dict[key]
continue
def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
"""
Convert a patched HF format Llama4 model (with separated projections)
back to the original HF format (with fused projections).
Args:
model: Path to the patched HF model
output: Path to save the converted model
"""
print(f"Loading model from {model}")
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output)
device = model_.device.type
if device == "cuda":
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_config = model_.config
config = model_.config.get_text_config()
# Get key dimensions from the config
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
num_experts = config.num_local_experts
print(
f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
)
# Create output directory if it doesn't exist
os.makedirs(output, exist_ok=True)
# Get state dict
state_dict = model_.state_dict()
del model_
# Create a new state dict for the converted model
converted_state_dict = {}
# First, copy all keys that don't need modification
for key, value in iter_convert_patched_to_hf(state_dict, num_experts):
converted_state_dict[key] = value
del state_dict
if device == "cuda":
torch.cuda.empty_cache()
print("State dict converted.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
# Ideally re-load the model import to load the converted state dict
# Save the converted model
with init_empty_weights():
unpatch_llama4()
model_ = Llama4ForConditionalGeneration(model_config)
if device == "cuda":
print("State dict loaded into model.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_.load_state_dict(converted_state_dict, strict=False, assign=True)
print(f"Saving converted model to {output}...")
model_.save_pretrained(output)
print(f"Model successfully converted and saved to {output}")
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -330,6 +330,15 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
cli.add_command(lm_eval)

View File

@@ -40,6 +40,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.warning("Error raised: %s", e)
model.generation_config.do_sample = True
model.config.use_cache = True
if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")

View File

@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator(
*collator_args,
@@ -1038,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -9,6 +9,8 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
},
)
# multi-modal section

View File

@@ -12,12 +12,14 @@ See https://github.com/apple/ml-cross-entropy
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
- If you are in dev environment
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
```
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
```
## Usage

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
)

View File

@@ -165,7 +165,7 @@ def cce_forward(
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
input_ids: torch.LongTensor | None = None, # type: ignore
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
if pixel_values is not None:
image_features = self.get_image_features(
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
original_inputs_embeds_shape = inputs_embeds.shape
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
final_mask_1d = final_mask[..., 0].reshape(-1)

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- gemma3
- granite
- jamba
- llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,

View File

@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
)
sharded_sd[param_name] = sharded_tensor
model.load_state_dict(sharded_sd)
model.load_state_dict(sharded_sd, assign=True)
def patch_accelerate_fsdp_utils():

View File

@@ -7,12 +7,11 @@ import torch
import transformers
def patch_flex_wrapper():
def patch_flex_wrapper(**flex_attn_compile_kwargs):
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
if not (is_torch_2_6 and is_transformers_below_4_51):
if not is_torch_2_6:
return
from torch.nn.attention.flex_attention import flex_attention
@@ -32,17 +31,24 @@ def patch_flex_wrapper():
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def del_singleton(cls):
cls._instance = None
@torch.compiler.disable(recursive=False)
def __init__(self):
def __init__(self, training):
"""
Initialize or update the singleton instance.
"""
if not self._is_flex_compiled:
self.training = None
if not self._is_flex_compiled or training != self.training:
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training
self._compiled_flex_attention = torch.compile(
flex_attention,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
**flex_attn_compile_kwargs,
)
self._is_flex_compiled = True
@@ -50,15 +56,22 @@ def patch_flex_wrapper():
return self._compiled_flex_attention
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
setattr(
sys.modules["transformers.integrations.flex_attention"],
"WrappedFlexAttention",
WrappedFlexAttention,
)
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
if not (is_torch_2_6 and is_transformers_eq_4_51):
if not is_torch_2_6:
return
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
)
from torch.nn.attention.flex_attention import (
BlockMask,
)
@@ -104,14 +117,16 @@ def patch_flex_make_mask():
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d, value=0, pad=(0, key_length)
attention_mask_2d,
value=0,
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
@@ -138,6 +153,18 @@ def patch_flex_make_mask():
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
mask_mod_maybe_combined = (
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
)
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
@@ -145,10 +172,10 @@ def patch_flex_make_mask():
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
mask_mod = mask_mod_maybe_combined
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
@@ -160,11 +187,16 @@ def patch_flex_make_mask():
)
for n in tuple(sys.modules):
if ".modeling_" in n and "llama4" not in n:
if ".modeling_" in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
setattr(
sys.modules[n],
"make_flex_block_causal_mask",
patched_make_flex_block_causal_mask,
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask

View File

@@ -0,0 +1,12 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -0,0 +1,192 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""
# pylint: disable=protected-access,cyclic-import
import os
from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
}
def create_flash_attn_forward(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
"""
Create a ring flash attention forward function compatible with HuggingFace's
interface.
Args:
process_group: A PyTorch distributed process group.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
Returns:
A function that implements the ring flash attention forward pass with the
signature expected by HuggingFace Transformers.
"""
# transformers 4.48+
# pylint: disable=unused-argument
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: torch.Tensor | None = None,
softmax_scale: float | None = None,
sliding_window: int | None = None,
use_top_left_mask: bool = False,
softcap: float | None = None,
deterministic: bool = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
**kwargs,
):
"""
Calls the forward method of Ring Flash Attention.
Args:
query_states: Tensor containing the query vectors.
key_states: Tensor containing the key vectors.
value_states: Tensor containing the value vectors.
attention_mask: Not used in this implementation.
query_length: Integer representing the length of the query sequence.
is_causal: Boolean indicating whether to apply a causal mask to the attention.
dropout: Float representing the dropout probability. Default is 0.0.
position_ids: Not used in this implementation.
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
sliding_window: Optional integer defining the size of the sliding attention window.
Default is None.
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
Default is False.
softcap: Not used in this implementation.
deterministic: Optional boolean to enforce deterministic computation. Default is None.
cu_seq_lens_q: Not used in this implementation.
cu_seq_lens_k: Not used in this implementation.
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
torch.Tensor: The output of the attention mechanism, with shape
`[batch_size, query_length, num_heads, head_dim]`.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
window_size = (
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
)
# Handle deterministic mode
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = (
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
# Call ring flash attention function
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
query_states,
key_states,
value_states,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
group=process_group,
)
return attn_output
return _flash_attention_forward
def substitute_hf_flash_attn(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
"""
Substitute HuggingFace's flash attention implementation with ring-based implementation.
Args:
process_group: PyTorch distributed process group for communication.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
"""
try:
# Substitute flash attention
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_flash_attn_forward(
process_group=process_group, ring_attn_func=ring_attn_func
)
if check_params(old_flash_attention_forward, new_flash_attention_forward):
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
else:
raise ValueError(
"The signature of the new flash attention forward function does not match the old one."
)
except Exception as exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"Please use pip install -U transformers to upgrade to the latest version. "
"If the code failed with the latest version, "
f"please file an issue."
) from exception
# Register with ALL_ATTENTION_FUNCTIONS if available
if ALL_ATTENTION_FUNCTIONS is not None:
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward

View File

@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""
from enum import Enum
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
`batch` function.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
rank = dist.get_rank()
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"must evenly divide world_size ({world_size})"
)
# Detailed logging of group formation
rank = dist.get_rank()
# Assign ranks to sequence parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func in [
RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG,
RingAttnFunc.BATCH_STRIPE,
]:
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(),
ring_attn_func=ring_attn_func,
)
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
def update_ring_attn_params(position_ids: torch.Tensor | None):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.
Args:
batch: A dictionary with a batch of data. May or may not contain `position_ids`
data; if not, we compute it.
position_ids: Optional tensor of position IDs (for sample packed data).
"""
from ring_flash_attn import update_ring_flash_attn_params
input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(
0, seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -93,9 +93,20 @@ def patch_llama4_linearized_modeling():
"""
from transformers.models.llama4 import modeling_llama4
old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts
modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
Llama4TextExperts,
)
def unpatch():
modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
old_lamma_4_text_experts,
)
return unpatch

View File

@@ -0,0 +1,78 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
training_loop = inspect.getsource(Trainer.evaluation_loop)
return training_loop
def check_evaluation_loop_is_patchable() -> bool:
eval_loop = get_evaluation_loop_code()
eval_loop, _ = detab_code(eval_loop)
return ORIGINAL_TRAINER_CODE in eval_loop
def patch_evaluation_loop_for_fsdp2():
"""
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
"""
try:
evaluation_loop = get_evaluation_loop_code()
except OSError:
return
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
evaluation_loop
)
evaluation_loop, _ = detab_code(evaluation_loop)
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
return
evaluation_loop = evaluation_loop.replace(
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
)
evaluation_loop = evaluation_loop.replace(
"def evaluation_loop(",
"def _fixed_evaluation_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in evaluation_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer.evaluation_loop = ( # pylint: disable=protected-access
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -81,6 +81,11 @@ def setup_model_and_tokenizer(
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
if any(
any(embed in param for embed in ["lm_head", "embed_tokens"])
for param in cfg.unfrozen_parameters
):
model.enable_input_require_grads()
return model, tokenizer, peft_config, processor

View File

@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any
import numpy as np
import torch
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
model: Any | None = None
padding: bool | str | PaddingStrategy = True
max_length: int | None = None
pad_to_multiple_of: int | None = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self):
if self.sequence_parallel_degree > 1:
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].shape[1]
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
total_seq_len = batch["input_ids"].size(1)
# Update params for ring attention calculation
update_ring_attn_params(batch=batch)
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
batch[key] = batch[key][:, start:end]
for key in batch:
if batch[key].size(1) == total_seq_len:
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch

View File

@@ -236,6 +236,18 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg):
"""

View File

@@ -3,6 +3,7 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
@@ -332,16 +351,23 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
def gen_from_iter_ds(_ds, _=None):
yield from _ds
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
"""Generator function to correctly splice the dataset for each worker"""
for i, item in enumerate(_ds):
if i % num_workers[0] == worker_id[0]:
yield item
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
num_proc=num_workers,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:

View File

@@ -2,13 +2,14 @@
module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple, Union
from accelerate.logging import get_logger
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze")
LOG = get_logger(__name__)
def freeze_layers_except(model, regex_patterns):
@@ -184,7 +185,7 @@ class LayerNamePattern:
"""
self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
def match(self, name: str) -> bool:
"""

View File

@@ -36,6 +36,7 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
@@ -542,6 +543,17 @@ class ModelLoader:
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
patch_flex_make_mask()
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
@@ -644,6 +656,7 @@ class ModelLoader:
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
)
def patch_attention(self) -> None:
@@ -821,6 +834,13 @@ class ModelLoader:
del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -842,21 +862,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
):
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
}
quant_config_class = quant_config_class_dict[
self.model_config.quantization_config["quant_method"]
]
self.model_kwargs["quantization_config"] = quant_config_class(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
@@ -874,8 +894,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs)
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
@@ -891,6 +911,13 @@ class ModelLoader:
**bnb_config,
)
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
@@ -905,13 +932,6 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
@@ -1031,6 +1051,12 @@ class ModelLoader:
config=self.model_config,
)
else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1115,7 +1141,7 @@ class ModelLoader:
return skip_move_to_device
def ajust_model_config(self) -> None:
def adjust_model_config(self) -> None:
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings")
@@ -1185,7 +1211,7 @@ class ModelLoader:
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1275,7 +1301,7 @@ class ModelLoader:
else:
self.model.tie_weights()
self.ajust_model_config()
self.adjust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in (
@@ -1455,7 +1481,16 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
from hqq.core.peft import HQQLinearLoRA
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
self.last_step = max(last_step - 1, 0)
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:

View File

@@ -225,6 +225,7 @@ class AxolotlInputConfig(
sdp_attention: bool | None = None
s2_attention: bool | None = None
flex_attention: bool | None = None
flex_attn_compile_kwargs: dict[str, Any] | None = None
flash_attention: bool | None = None
flash_attn_cross_entropy: bool | None = None
flash_attn_rms_norm: bool | None = None
@@ -258,6 +259,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: str | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -658,6 +660,7 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"
@@ -1146,7 +1149,7 @@ class AxolotlInputConfig(
return data
@field_validator("sequence_parallel_degree", mode="before")
@field_validator("sequence_parallel_degree", mode="after")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
@@ -1158,9 +1161,12 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if not info.data["micro_batch_size"] == 1:
if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
raise ValueError(
"micro_batch_size must be set to 1 "
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"
)
@@ -1187,6 +1193,34 @@ class AxolotlInputConfig(
return value
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
@@ -1276,11 +1310,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
if capabilities and capabilities.get("n_gpu", 0) > 1:
is_fsdp2 = (
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
)
return data

View File

@@ -1,9 +1,9 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.quant import QuantizationConfig
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
quantization: QuantizationConfig | None = None
load_in_4bit: bool | None = None # for internal use
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
adapter: str | None = None
lora_model_dir: str | None = None
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
and (data.get("quantization"))
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
"Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization."
)
return data
@@ -86,25 +87,26 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.load_in_4bit:
if self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.quantization:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@@ -121,6 +123,24 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -0,0 +1,93 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs

View File

@@ -17,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -235,7 +236,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type in ["mamba", "gemma3"]:
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -625,6 +627,12 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if (
cfg.torch_compile
and cfg.fsdp_config
and str(cfg.fsdp_config.fsdp_version) == "2"
):
patch_evaluation_loop_for_fsdp2()
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref

View File

@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(

View File

View File

@@ -3,13 +3,14 @@
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ...utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
def test_sequence_parallel_training(self, temp_dir):
# pylint: disable=duplicate-code
def _run_sequence_parallel_test(
self,
temp_dir,
sample_packing=True,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"sample_packing": sample_packing,
"eval_sample_packing": sample_packing,
"pad_to_sequence_len": pad_to_sequence_len,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
],
"num_epochs": 1,
"max_steps": 8,
"micro_batch_size": 1,
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
self,
temp_dir,
sample_packing,
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
)

View File

@@ -0,0 +1,2 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,18 +49,20 @@ class TestPackedFlex:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
}

View File

@@ -177,6 +177,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
cfg.base_model,
@@ -264,6 +265,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
cfg.base_model,

View File

@@ -30,8 +30,10 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
@@ -99,8 +101,10 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",

View File

@@ -171,7 +171,10 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -249,7 +252,10 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
@@ -548,7 +554,10 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -621,12 +630,6 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
@@ -654,7 +657,10 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}
@@ -728,7 +734,10 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}
@@ -802,7 +811,10 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}

View File

@@ -28,7 +28,10 @@ class TestMultiGPUQwen2:
cfg = DictDefault(
{
"base_model": base_model,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,

View File

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
)
peft_config = get_peft_config(
{
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
)
# Apply LoRA configuration

View File

@@ -32,7 +32,10 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,

View File

@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,

View File

@@ -33,7 +33,10 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,

View File

@@ -46,8 +46,9 @@ class TestResumeLlama:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 2,

View File

@@ -73,7 +73,10 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
@@ -82,7 +85,11 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2

View File

@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -34,7 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,

View File

@@ -35,7 +35,10 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,
@@ -91,7 +94,10 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,

View File

@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -0,0 +1,141 @@
"""
E2E tests for training with quantized model
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestHQQ(unittest.TestCase):
"""
Test cases for training of HQQ-quantized llama models"""
@with_temp_dir
def test_hqq_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 8,
"group_size": 64,
}
],
"adapter": "lora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@with_temp_dir
def test_hqq_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 4,
"group_size": 64,
}
],
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"adapter": "qlora",
}
| minimal_cfg
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
"deepspeed": "",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
}
| minimal_cfg
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
"deepspeed": None,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
}
| minimal_cfg
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
}
)
| base_cfg
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
"quantization": {
"backend": "gptq",
},
}
)
| base_cfg
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": False,
"quantization": {
"bits": None,
},
}
)
| base_cfg
)
with pytest.raises(ValueError, match=r".*4bit.*"):
with pytest.raises(ValueError, match=r".*bits <= 4*"):
validate_config(cfg)
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
)
| base_cfg
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
}
)
| base_cfg
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
"quantization": {
"backend": "gptq",
"bits": 4,
},
}
)
| base_cfg
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
"quantization": {
"bits": 4,
},
}
)
| base_cfg
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"load_in_4bit": True,
"quantization": {
"bits": None,
},
}
)
| minimal_cfg
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
match=r"Quantization is not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_8bit": True,
}
)
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"quantization": {
"bits": 4,
},
"adapter": "qlora",
}
)
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"load_in_8bit": True,
"quantization": {
"bits": 8,
},
"adapter": "lora",
}
)

View File

@@ -21,8 +21,10 @@ class TestModelsUtils:
"base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"load_in_8bit": True,
"load_in_4bit": False,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"flash_attention": False,
"sample_packing": True,