Compare commits
51 Commits
transforme
...
feat_hqq
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0179021780 | ||
|
|
c4910da015 | ||
|
|
db7e92f6a6 | ||
|
|
136b37e4d4 | ||
|
|
92644513c4 | ||
|
|
266ef3f479 | ||
|
|
fcef8c95fe | ||
|
|
136407c556 | ||
|
|
3251b3235f | ||
|
|
1aa9f7d952 | ||
|
|
a20e753321 | ||
|
|
cb121ab91b | ||
|
|
b59640a4c7 | ||
|
|
f0a189131b | ||
|
|
c8fb5baad6 | ||
|
|
9be971d47c | ||
|
|
ffd4ef1ece | ||
|
|
320aff1867 | ||
|
|
ac24eba2ac | ||
|
|
8a5ad8aee3 | ||
|
|
843b50fdaa | ||
|
|
098ffcc5a2 | ||
|
|
ba8e29c841 | ||
|
|
143b2e082c | ||
|
|
aba484de97 | ||
|
|
f6f5f89c6d | ||
|
|
8926fe9981 | ||
|
|
987c5217a0 | ||
|
|
feaef03cb9 | ||
|
|
ba5d917845 | ||
|
|
0e9b060b4d | ||
|
|
0c40d12a18 | ||
|
|
f55b3c805b | ||
|
|
a64601f957 | ||
|
|
eb7bc70b99 | ||
|
|
db6c76b147 | ||
|
|
99730ce40a | ||
|
|
7651550850 | ||
|
|
341e95aac9 | ||
|
|
b882dfb63f | ||
|
|
b640db1dbc | ||
|
|
4ce469d32e | ||
|
|
60a8f0958d | ||
|
|
9da730d6a4 | ||
|
|
32637fad00 | ||
|
|
f776f889a1 | ||
|
|
69eda209a6 | ||
|
|
b8c633aa97 | ||
|
|
682a9cf79b | ||
|
|
271b24cccc | ||
|
|
198d775d6d |
14
.coveragerc
Normal file
14
.coveragerc
Normal file
@@ -0,0 +1,14 @@
|
||||
[run]
|
||||
source = axolotl
|
||||
omit =
|
||||
*/tests/*
|
||||
setup.py
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
def __repr__
|
||||
raise NotImplementedError
|
||||
if __name__ == .__main__.:
|
||||
pass
|
||||
raise ImportError
|
||||
2
.github/workflows/main.yml
vendored
2
.github/workflows/main.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
13
.github/workflows/tests.yml
vendored
13
.github/workflows/tests.yml
vendored
@@ -102,9 +102,16 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
|
||||
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
|
||||
<br/>
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
|
||||
|
||||
63
cicd/cicd.sh
63
cicd/cicd.sh
@@ -3,10 +3,59 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
--ignore=tests/e2e/ \
|
||||
--ignore=tests/patched/ \
|
||||
--ignore=tests/cli \
|
||||
/workspace/axolotl/tests/ \
|
||||
--cov=axolotl \
|
||||
--cov-report=xml:coverage.xml
|
||||
|
||||
# Run lora kernels tests with coverage append
|
||||
pytest -v --durations=10 \
|
||||
/workspace/axolotl/tests/e2e/patched/lora_kernels \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run patched tests excluding lora kernels with coverage append
|
||||
pytest -v --durations=10 \
|
||||
--ignore=tests/e2e/patched/lora_kernels \
|
||||
/workspace/axolotl/tests/e2e/patched \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run solo tests with coverage append
|
||||
pytest -v --durations=10 -n1 \
|
||||
/workspace/axolotl/tests/e2e/solo/ \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run integration tests with coverage append
|
||||
pytest -v --durations=10 \
|
||||
/workspace/axolotl/tests/e2e/integrations/ \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/cli \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run remaining e2e tests with coverage append and final report
|
||||
pytest -v --durations=10 \
|
||||
--ignore=tests/e2e/solo/ \
|
||||
--ignore=tests/e2e/patched/ \
|
||||
--ignore=tests/e2e/multigpu/ \
|
||||
--ignore=tests/e2e/integrations/ \
|
||||
--ignore=tests/cli \
|
||||
/workspace/axolotl/tests/e2e/ \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:coverage.xml
|
||||
|
||||
# Upload coverage to Codecov
# The last pytest invocation above writes the combined XML report to
# coverage.xml (--cov-report=xml:coverage.xml), so that is the file that must
# be checked for and uploaded.
if [ -f coverage.xml ]; then
    codecov -f coverage.xml -F "e2e,pytorch-${PYTORCH_VERSION}"
else
    echo "Coverage file not found. Coverage report may have failed."
fi
|
||||
|
||||
@@ -1,6 +1,27 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v -n2 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
--cov=axolotl \
|
||||
--cov-report=xml:multigpu-coverage.xml
|
||||
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:multigpu-coverage.xml
|
||||
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:multigpu-coverage.xml
|
||||
|
||||
# Upload coverage to Codecov
|
||||
if [ -f multigpu-coverage.xml ]; then
|
||||
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
|
||||
else
|
||||
echo "Coverage file not found. Coverage report may have failed."
|
||||
fi
|
||||
|
||||
51
codecov.yml
Normal file
51
codecov.yml
Normal file
@@ -0,0 +1,51 @@
|
||||
codecov:
|
||||
require_ci_to_pass: yes
|
||||
|
||||
coverage:
|
||||
precision: 2
|
||||
round: down
|
||||
range: "70...100"
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
# basic
|
||||
target: auto
|
||||
threshold: 0%
|
||||
base: auto
|
||||
# advanced
|
||||
branches: null
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
patch:
|
||||
default:
|
||||
# basic
|
||||
target: auto
|
||||
threshold: 0%
|
||||
base: auto
|
||||
# advanced
|
||||
branches: null
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
|
||||
parsers:
|
||||
gcov:
|
||||
branch_detection:
|
||||
conditional: yes
|
||||
loop: yes
|
||||
method: no
|
||||
macro: no
|
||||
|
||||
comment:
|
||||
layout: "reach,diff,flags,files,footer"
|
||||
behavior: default
|
||||
require_changes: no
|
||||
require_base: no
|
||||
require_head: yes
|
||||
@@ -55,20 +55,46 @@ overrides_of_model_config:
|
||||
overrides_of_model_kwargs:
|
||||
# use_cache: False
|
||||
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
# These are default values
|
||||
llm_int8_has_fp16_weight: false
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
|
||||
# Quantization configuration.
|
||||
quantization:
|
||||
backend: bnb | hqq | gptq
|
||||
bits: 8
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
# These are default values
|
||||
llm_int8_has_fp16_weight: false
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
# If using hqq config, additional config parameters are needed. See: https://huggingface.co/docs/transformers/main/en/quantization/hqq
|
||||
hqq_config:
|
||||
# pick one of the following, depending on if you want to uniformly quantize the whole model or
|
||||
# apply different quantization settings to specific layers in the model:
|
||||
|
||||
# if uniformly quantize the whole model:
|
||||
group_size: 64
|
||||
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
|
||||
- nbits: 4
|
||||
group_size: 64
|
||||
target_modules:
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- nbits: 3
|
||||
group_size: 32
|
||||
target_modules:
|
||||
- mlp.gate_proj
|
||||
- mlp.up_proj
|
||||
- mlp.down_proj
|
||||
|
||||
# (Internal Use Only)
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
|
||||
gptq:
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
load_in_8bit:
|
||||
# Use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
@@ -693,6 +719,9 @@ sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
|
||||
# in the sample packing case, and "batch_ring" in the non-sample packing case.
|
||||
ring_attn_func:
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
# Llama 4 by Meta AI
|
||||
|
||||
## Flash Attention vs Flex Attention
|
||||
|
||||
While Flash Attention support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
|
||||
|
||||
## Available Examples
|
||||
|
||||
### Llama 4 Scout 17Bx16Experts (109B)
|
||||
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
|
||||
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
|
||||
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
|
||||
|
||||
Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
|
||||
Flex Attention
|
||||
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
|
||||
- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
|
||||
|
||||
[//]: # (Flash Attention (Do not use))
|
||||
|
||||
[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml))
|
||||
|
||||
[//]: # (- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml))
|
||||
|
||||
[//]: # (- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml))
|
||||
|
||||
Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
|
||||
Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @ 280tps/gpu, [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
|
||||
|
||||
### Llama 4 Maverick 17Bx128Experts (400B)
|
||||
|
||||
- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
|
||||
|
||||
Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
|
||||
Coming Soon
|
||||
|
||||
86
examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
Normal file
86
examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
|
||||
llama4_linearized_experts: true
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
# - experts.gate_projs.[0-9]+$
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
# - lm_head
|
||||
# - embed_tokens
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- auto_wrap
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
85
examples/llama-4/scout-qlora-single-h100-flex.yaml
Normal file
85
examples/llama-4/scout-qlora-single-h100-flex.yaml
Normal file
@@ -0,0 +1,85 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
cut_cross_entropy: true
|
||||
|
||||
llama4_linearized_experts: true # needed with custom linearized experts model
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
# - experts.gate_projs.[0-9]+$ # optionally train the moe experts
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
# - lm_head # needed if modifying vocabulary
|
||||
# - embed_tokens
|
||||
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096 # up to 8k will work on a single H100
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
torch_compile: true
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
gradient_checkpointing: offload
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
logging_steps: 1
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
89
examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
Normal file
89
examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
|
||||
model_type: Llama4ForConditionalGeneration
|
||||
processor_type: Llama4Processor
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
sequence_len: 4096
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_glu_activation: true
|
||||
liger_rms_norm: true
|
||||
liger_layer_norm: true
|
||||
|
||||
llama4_linearized_experts: true # use Axolotl's customized model
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_modules:
|
||||
- self_attn.q_proj
|
||||
- self_attn.k_proj
|
||||
- self_attn.v_proj
|
||||
- self_attn.o_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert.up_proj
|
||||
- shared_expert.down_proj
|
||||
- vision_adapter.mlp.fc1
|
||||
- vision_adapter.mlp.fc2
|
||||
# - experts.gate_projs.[0-9]+$
|
||||
# - experts.up_projs.[0-9]+$
|
||||
# - experts.down_projs.[0-9]+$
|
||||
lora_modules_to_save:
|
||||
- lm_head
|
||||
- embed_tokens
|
||||
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- auto_wrap
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot|>
|
||||
@@ -1,6 +1,6 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
pre-commit
|
||||
types-requests
|
||||
quartodoc
|
||||
jupyter
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
codecov
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest-cov
|
||||
pytest-retry
|
||||
pytest-sugar
|
||||
pytest-xdist
|
||||
tbparse
|
||||
|
||||
@@ -6,13 +6,13 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.6
|
||||
liger-kernel==0.5.8
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.1
|
||||
transformers==4.51.1
|
||||
transformers==4.51.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
@@ -22,6 +22,7 @@ hf_xet==1.0.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
hqq==0.2.5
|
||||
sentencepiece
|
||||
gradio==5.23.3
|
||||
|
||||
|
||||
@@ -25,5 +25,5 @@ if cce_spec:
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
|
||||
)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -67,7 +67,7 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.3"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
|
||||
156
src/axolotl/cli/delinearize_llama4.py
Normal file
156
src/axolotl/cli/delinearize_llama4.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
CLI tool to delinearize quantized/Linearized Llama-4 models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Generator, Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
def _stack_expert_weights(model_state_dict, prefix, proj_name, num_experts):
    """Stack the transposed per-expert ``weight`` tensors of one projection.

    Reads ``{prefix}.{proj_name}.{i}.weight`` for every ``i`` in
    ``range(num_experts)``, swaps its first two dimensions and stacks the
    results along a new leading dimension of size ``num_experts``.
    """
    return torch.stack(
        [
            model_state_dict[f"{prefix}.{proj_name}.{expert_idx}.weight"].transpose(
                0, 1
            )
            for expert_idx in range(num_experts)
        ]
    )


def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
    """Yield ``(key, tensor)`` pairs with per-expert projections fused back together.

    Entries whose key does not contain ``.feed_forward.experts.`` are yielded
    unchanged. For expert entries:

    * ``gate_projs`` keys yield nothing on their own; they are consumed when the
      matching ``up_projs.0.`` key is reached.
    * the ``up_projs.0.`` key yields ``{prefix}.gate_up_proj``: the stacked,
      transposed gate weights concatenated with the stacked, transposed up
      weights along the last dimension.
    * the ``down_projs.0.`` key yields ``{prefix}.down_proj``: the stacked,
      transposed down weights.
    * ``up_projs`` / ``down_projs`` keys of any other expert index yield nothing
      and are deleted from ``model_state_dict``.

    Note that ``model_state_dict`` is mutated (entries are deleted) while the
    generator runs. The fused tensors are built when the index-0 key is
    visited, so that key must come before the other experts' keys in the dict.

    Args:
        model_state_dict: mapping of parameter names to tensors of the patched
            (linearized) model.
        num_experts: number of experts whose weights are stacked per projection.

    Yields:
        Tuples of ``(key, tensor)`` for the converted state dict.
    """
    # Snapshot the keys up front: entries are deleted from the dict below.
    for key in list(model_state_dict.keys()):
        if ".feed_forward.experts." not in key:
            yield key, model_state_dict[key]
            continue

        if ".feed_forward.experts.gate_projs" in key:
            # gate gets fused with up, so skip the yield here; it is fused when
            # the up projection is handled
            continue

        if ".feed_forward.experts.up_projs" in key:
            if ".feed_forward.experts.up_projs.0." in key:
                # fuse gate and up across all experts, converting the
                # per-expert linear weights into a single stacked tensor
                prefix = key.split(".up_projs.0.")[0]
                gate_stacked = _stack_expert_weights(
                    model_state_dict, prefix, "gate_projs", num_experts
                )
                up_stacked = _stack_expert_weights(
                    model_state_dict, prefix, "up_projs", num_experts
                )
                gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
                del gate_stacked, up_stacked
                yield f"{prefix}.gate_up_proj", gate_up_proj
            else:
                # already folded into the fused tensor built at expert 0
                del model_state_dict[key]
            continue

        if ".feed_forward.experts.down_projs" in key:
            if ".feed_forward.experts.down_projs.0." in key:
                # stack the down projections of all experts into one tensor
                prefix = key.split(".down_projs.0.")[0]
                yield f"{prefix}.down_proj", _stack_expert_weights(
                    model_state_dict, prefix, "down_projs", num_experts
                )
            else:
                # already folded into the stacked tensor built at expert 0
                del model_state_dict[key]
            continue
|
||||
|
||||
|
||||
def _report_cuda_peak_memory() -> None:
    """Print the peak CUDA memory allocated and reserved so far, in MB."""
    print(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
    )
    print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")


def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
    """
    Convert a patched HF format Llama4 model (with separated projections)
    back to the original HF format (with fused projections).

    The processor found at ``model`` is saved to ``output`` as well, so the
    output directory holds both the processor and the converted model.

    Args:
        model: Path to the patched HF model
        output: Path to save the converted model
    """
    print(f"Loading model from {model}")
    from axolotl.monkeypatch.models.llama4.modeling import (
        patch_llama4_linearized_modeling,
    )

    # The patch call returns a callable that is invoked further down, before
    # the model is re-created for saving.
    # NOTE(review): presumably it restores the stock HF Llama4 modeling - confirm.
    unpatch_llama4 = patch_llama4_linearized_modeling()
    from transformers import Llama4ForConditionalGeneration

    model_ = Llama4ForConditionalGeneration.from_pretrained(
        model, torch_dtype=torch.bfloat16
    )
    processor = AutoProcessor.from_pretrained(model)
    processor.save_pretrained(output)

    device = model_.device.type
    if device == "cuda":
        _report_cuda_peak_memory()

    model_config = model_.config
    config = model_.config.get_text_config()

    # Get key dimensions from the config
    hidden_size = config.hidden_size
    intermediate_size = config.intermediate_size
    num_experts = config.num_local_experts

    print(
        f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
    )

    # Create output directory if it doesn't exist
    os.makedirs(output, exist_ok=True)

    # Get state dict and drop the patched model so only the tensors stay alive
    state_dict = model_.state_dict()
    del model_

    # Build the converted state dict: unmodified keys are copied through and
    # the per-expert projections are fused by the generator
    converted_state_dict = dict(iter_convert_patched_to_hf(state_dict, num_experts))

    del state_dict
    print("State dict converted.")
    if device == "cuda":
        torch.cuda.empty_cache()
        _report_cuda_peak_memory()

    # Re-create the model from its config on empty weights after undoing the
    # patch, then load the converted state dict into it
    with init_empty_weights():
        unpatch_llama4()
        model_ = Llama4ForConditionalGeneration(model_config)

    # strict=False does not raise on key mismatches, so surface them instead of
    # silently saving a model with parameters that were never assigned.
    load_result = model_.load_state_dict(
        converted_state_dict, strict=False, assign=True
    )
    if load_result.missing_keys:
        print(
            f"Warning: keys missing from converted state dict: {load_result.missing_keys}"
        )
    if load_result.unexpected_keys:
        print(
            f"Warning: unexpected keys in converted state dict: {load_result.unexpected_keys}"
        )

    print("State dict loaded into model.")
    if device == "cuda":
        _report_cuda_peak_memory()

    print(f"Saving converted model to {output}...")
    model_.save_pretrained(output)

    print(f"Model successfully converted and saved to {output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
@@ -330,6 +330,15 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
    """
    Convert a linearized Llama 4 model back to the fused-projection HF format.

    MODEL is the path to the existing patched model. OUTPUT is the path where
    the converted model is saved.
    """
    # Imported inside the command rather than at module level.
    # NOTE(review): presumably to avoid loading torch/transformers for other
    # CLI commands - confirm.
    from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4

    do_delinearize_llama4(model, output)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
LOG.warning("Error raised: %s", e)
|
||||
|
||||
model.generation_config.do_sample = True
|
||||
model.config.use_cache = True
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
|
||||
|
||||
@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
kwargs["return_tensors"] = "pt"
|
||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||
kwargs["ring_attn_func"] = training_args.ring_attn_func
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
@@ -1038,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
||||
elif self.cfg.rl_beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
elif self.cfg.orpo_alpha is not None:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
ring_attn_func: Optional[RingAttnFunc] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The ring-flash-attn function to use in sequence parallelism"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@ See https://github.com/apple/ml-cross-entropy
|
||||
|
||||
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
|
||||
|
||||
- If you are in a dev environment
|
||||
```bash
|
||||
# if you are in dev environment
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
# if you are not in dev environment
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -165,7 +165,7 @@ def cce_forward(
|
||||
)
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
|
||||
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape
|
||||
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
|
||||
@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
|
||||
- deepseek_v2
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3 (partial support, no support for FLCE yet)
|
||||
- gemma3
|
||||
- granite
|
||||
- jamba
|
||||
- llama
|
||||
|
||||
@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
|
||||
)
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
|
||||
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||
|
||||
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||
_liger_rms_norm_wrapper,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
raise NotImplementedError(
|
||||
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||
)
|
||||
elif cfg.model_config_type == "llama4":
|
||||
from axolotl.integrations.liger.models.llama4 import (
|
||||
apply_liger_kernel_to_llama4,
|
||||
|
||||
@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
|
||||
)
|
||||
sharded_sd[param_name] = sharded_tensor
|
||||
|
||||
model.load_state_dict(sharded_sd)
|
||||
model.load_state_dict(sharded_sd, assign=True)
|
||||
|
||||
|
||||
def patch_accelerate_fsdp_utils():
|
||||
|
||||
@@ -7,12 +7,11 @@ import torch
|
||||
import transformers
|
||||
|
||||
|
||||
def patch_flex_wrapper():
|
||||
def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||
# TODO remove this patch when transformers#37285 is merged and in a release
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
|
||||
|
||||
if not (is_torch_2_6 and is_transformers_below_4_51):
|
||||
if not is_torch_2_6:
|
||||
return
|
||||
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
@@ -32,17 +31,24 @@ def patch_flex_wrapper():
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def del_singleton(cls):
|
||||
cls._instance = None
|
||||
|
||||
@torch.compiler.disable(recursive=False)
|
||||
def __init__(self):
|
||||
def __init__(self, training):
|
||||
"""
|
||||
Initialize or update the singleton instance.
|
||||
"""
|
||||
if not self._is_flex_compiled:
|
||||
self.training = None
|
||||
if not self._is_flex_compiled or training != self.training:
|
||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||
self.training = training
|
||||
self._compiled_flex_attention = torch.compile(
|
||||
flex_attention,
|
||||
dynamic=False,
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
fullgraph=True,
|
||||
**flex_attn_compile_kwargs,
|
||||
)
|
||||
self._is_flex_compiled = True
|
||||
|
||||
@@ -50,15 +56,22 @@ def patch_flex_wrapper():
|
||||
return self._compiled_flex_attention
|
||||
|
||||
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
|
||||
setattr(
|
||||
sys.modules["transformers.integrations.flex_attention"],
|
||||
"WrappedFlexAttention",
|
||||
WrappedFlexAttention,
|
||||
)
|
||||
|
||||
|
||||
def patch_flex_make_mask():
|
||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
||||
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
|
||||
|
||||
if not (is_torch_2_6 and is_transformers_eq_4_51):
|
||||
if not is_torch_2_6:
|
||||
return
|
||||
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
|
||||
)
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
)
|
||||
@@ -104,14 +117,16 @@ def patch_flex_make_mask():
|
||||
if not query_length:
|
||||
query_length = total_seq_len
|
||||
attention_mask_2d = torch.nn.functional.pad(
|
||||
attention_mask_2d, value=0, pad=(0, key_length)
|
||||
attention_mask_2d,
|
||||
value=0,
|
||||
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
|
||||
)
|
||||
device = attention_mask_2d.device
|
||||
document_ids = attention_mask_2d.clone()
|
||||
|
||||
if attention_chunk_size is not None:
|
||||
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
|
||||
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
|
||||
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
|
||||
attention_chunk_size
|
||||
)
|
||||
|
||||
@@ -138,6 +153,18 @@ def patch_flex_make_mask():
|
||||
final_mask = causal_mask & padding_mask & document_mask
|
||||
return final_mask
|
||||
|
||||
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Combines the chunk mask with the causal mask for chunked attention.
|
||||
"""
|
||||
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
|
||||
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
|
||||
return chunk_mask & causal_doc_mask
|
||||
|
||||
mask_mod_maybe_combined = (
|
||||
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
|
||||
)
|
||||
|
||||
if offsets is not None:
|
||||
q_offset = offsets[0]
|
||||
kv_offset = offsets[1]
|
||||
@@ -145,10 +172,10 @@ def patch_flex_make_mask():
|
||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
offset_q = q_idx + q_offset
|
||||
offset_kv = kv_idx + kv_offset
|
||||
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
|
||||
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
|
||||
|
||||
else:
|
||||
mask_mod = causal_mask_mod
|
||||
mask_mod = mask_mod_maybe_combined
|
||||
return create_block_causal_mask_flex(
|
||||
mask_mod=mask_mod,
|
||||
B=batch_size,
|
||||
@@ -160,11 +187,16 @@ def patch_flex_make_mask():
|
||||
)
|
||||
|
||||
for n in tuple(sys.modules):
|
||||
if ".modeling_" in n and "llama4" not in n:
|
||||
if ".modeling_" in n:
|
||||
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
|
||||
sys.modules[n].make_flex_block_causal_mask = (
|
||||
patched_make_flex_block_causal_mask
|
||||
)
|
||||
setattr(
|
||||
sys.modules[n],
|
||||
"make_flex_block_causal_mask",
|
||||
patched_make_flex_block_causal_mask,
|
||||
)
|
||||
|
||||
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
||||
patched_make_flex_block_causal_mask
|
||||
|
||||
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Init for ring attention monkeypatch module"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .patch import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
HuggingFace flash attention adapter for basic ring attention (batch API).
|
||||
|
||||
Inspired by
|
||||
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
|
||||
Our implementation closely follows the structure of that module, but we've minified it
|
||||
somewhat to support only the latest versions of transformers.
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access,cyclic-import
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from ring_flash_attn import (
|
||||
ring_flash_attn_func,
|
||||
stripe_flash_attn_func,
|
||||
zigzag_ring_flash_attn_func,
|
||||
)
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size,
|
||||
is_flash_attn_greater_or_equal,
|
||||
)
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
RING_ATTN_FUNC_MAPPING = {
|
||||
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
|
||||
}
|
||||
|
||||
|
||||
def create_flash_attn_forward(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
    """
    Create a ring flash attention forward function compatible with HuggingFace's
    interface.

    Args:
        process_group: A PyTorch distributed process group.
        ring_attn_func: Enum member used as the key into `RING_ATTN_FUNC_MAPPING`
            to select the `ring_flash_attn` batch function that replaces HF flash
            attention.

    Returns:
        A function that implements the ring flash attention forward pass with the
        signature expected by HuggingFace Transformers.
    """

    # transformers 4.48+
    # pylint: disable=unused-argument
    def _flash_attention_forward(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: torch.Tensor | None = None,
        softmax_scale: float | None = None,
        sliding_window: int | None = None,
        use_top_left_mask: bool = False,
        softcap: float | None = None,
        deterministic: bool | None = None,
        cu_seq_lens_q: torch.LongTensor | None = None,
        cu_seq_lens_k: torch.LongTensor | None = None,
        max_length_q: int | None = None,
        max_length_k: int | None = None,
        target_dtype: torch.dtype | None = None,
        **kwargs,
    ):
        """
        Calls the forward method of Ring Flash Attention.

        Args:
            query_states: Tensor containing the query vectors.
            key_states: Tensor containing the key vectors.
            value_states: Tensor containing the value vectors.
            attention_mask: Not used in this implementation.
            query_length: Integer representing the length of the query sequence.
            is_causal: Boolean indicating whether to apply a causal mask to the attention.
            dropout: Float representing the dropout probability. Default is 0.0.
            position_ids: Not used in this implementation.
            softmax_scale: Optional float value for the softmax scaling factor. Default is None.
            sliding_window: Optional integer defining the size of the sliding attention window.
                Default is None.
            use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
                Default is False.
            softcap: Not used in this implementation.
            deterministic: Optional boolean to enforce deterministic computation. When
                None, it is read from the `FLASH_ATTENTION_DETERMINISTIC` environment
                variable if flash-attn >= 2.4.1, and is False otherwise.
            cu_seq_lens_q: Not used in this implementation.
            cu_seq_lens_k: Not used in this implementation.
            max_length_q: Not used in this implementation.
            max_length_k: Not used in this implementation.
            target_dtype: Not used in this implementation.
            **kwargs: Additional keyword arguments. Not used in this implementation.

        Returns:
            torch.Tensor: The output of the attention mechanism, with shape
                `[batch_size, query_length, num_heads, head_dim]`.
        """
        if not use_top_left_mask:
            causal = is_causal
        else:
            # With a top-left mask, a single-token query is not treated as causal.
            causal = is_causal and query_length != 1

        # Handle sliding window
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        window_size = (
            (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
        )

        # Handle deterministic mode: always hand the kernel a real bool, never None.
        if deterministic is None:
            deterministic = (
                is_flash_attn_greater_or_equal("2.4.1")
                and os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
            )

        # Call ring flash attention function
        attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
            query_states,
            key_states,
            value_states,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            alibi_slopes=None,
            deterministic=deterministic,
            return_attn_probs=False,
            group=process_group,
        )

        return attn_output

    return _flash_attention_forward
|
||||
|
||||
|
||||
def substitute_hf_flash_attn(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
    """
    Swap HuggingFace's flash attention forward for the ring-based one.

    Replaces `transformers.modeling_flash_attention_utils._flash_attention_forward`
    with the callable built by `create_flash_attn_forward`, then overrides the
    `"flash_attention_2"` entry of `ALL_ATTENTION_FUNCTIONS`.

    Args:
        process_group: PyTorch distributed process group for communication.
        ring_attn_func: Function from `ring_flash_attention` to replace HF flash
            attention with.

    Raises:
        ValueError: If anything fails while substituting, including a signature
            mismatch reported by `check_params`; the underlying error is chained
            as the cause.
    """
    try:
        # Substitute flash attention
        hf_forward = transformers.modeling_flash_attention_utils._flash_attention_forward
        ring_forward = create_flash_attn_forward(
            process_group=process_group, ring_attn_func=ring_attn_func
        )

        # A falsy result from check_params is treated as a signature mismatch.
        if not check_params(hf_forward, ring_forward):
            raise ValueError(
                "The signature of the new flash attention forward function does not match the old one."
            )

        transformers.modeling_flash_attention_utils._flash_attention_forward = (
            ring_forward
        )
    except Exception as exception:
        raise ValueError(
            f"The current transformer version {transformers.__version__} is not supported. "
            "Please use pip install -U transformers to upgrade to the latest version. "
            "If the code failed with the latest version, "
            f"please file an issue."
        ) from exception

    # Register with ALL_ATTENTION_FUNCTIONS if available
    if ALL_ATTENTION_FUNCTIONS is None:
        return

    from ring_flash_attn.adapters.hf_adapter import flash_attention_forward

    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
||||
@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
class RingAttnFunc(str, Enum):
    """
    Supported `ring-flash-attn` implementations.

    Members are `str`-valued, so they compare equal to their plain string
    values. The `VARLEN_*` / `BATCH_*` prefix distinguishes the variable-length
    API from the batch API.
    """

    # Not currently exposed:
    # VARLEN_RING = "varlen_ring"
    # VARLEN_ZIGZAG = "varlen_zigzag"
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"
    BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
`batch` function.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Detailed logging of group formation
|
||||
rank = dist.get_rank()
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
elif ring_attn_func in [
|
||||
RingAttnFunc.BATCH_RING,
|
||||
RingAttnFunc.BATCH_ZIGZAG,
|
||||
RingAttnFunc.BATCH_STRIPE,
|
||||
]:
|
||||
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
||||
substitute_hf_flash_attn,
|
||||
)
|
||||
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(),
|
||||
ring_attn_func=ring_attn_func,
|
||||
)
|
||||
|
||||
|
||||
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
||||
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
"""
|
||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||
value to the substituted `ring_flash_attn`.
|
||||
|
||||
Args:
|
||||
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
||||
data; if not, we compute it.
|
||||
position_ids: Optional tensor of position IDs (for sample packed data).
|
||||
"""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
position_ids = batch.get("position_ids")
|
||||
if position_ids is None:
|
||||
seq_len = input_ids.shape[1]
|
||||
position_ids = torch.arange(
|
||||
0, seq_len, dtype=torch.long, device=input_ids.device
|
||||
).unsqueeze(0)
|
||||
|
||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
@@ -93,9 +93,20 @@ def patch_llama4_linearized_modeling():
|
||||
"""
|
||||
from transformers.models.llama4 import modeling_llama4
|
||||
|
||||
old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts
|
||||
modeling_llama4.Llama4TextExperts = Llama4TextExperts
|
||||
setattr(
|
||||
sys.modules["transformers.models.llama4"],
|
||||
"Llama4TextExperts",
|
||||
Llama4TextExperts,
|
||||
)
|
||||
|
||||
def unpatch():
|
||||
modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts
|
||||
setattr(
|
||||
sys.modules["transformers.models.llama4"],
|
||||
"Llama4TextExperts",
|
||||
old_lamma_4_text_experts,
|
||||
)
|
||||
|
||||
return unpatch
|
||||
|
||||
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
fix for FSDP2 evals when using torch.compile
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
|
||||
def get_evaluation_loop_code() -> str:
    """Return the source text of `Trainer.evaluation_loop` via `inspect.getsource`."""
    return inspect.getsource(Trainer.evaluation_loop)
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
    """
    Report whether `ORIGINAL_TRAINER_CODE` still appears in the detabbed source of
    `Trainer.evaluation_loop`, i.e. whether the patch has something to replace.
    """
    # detab_code returns a 2-tuple; only the first element (the code) is needed here.
    detabbed_source, _ = detab_code(get_evaluation_loop_code())
    return ORIGINAL_TRAINER_CODE in detabbed_source
|
||||
|
||||
|
||||
def patch_evaluation_loop_for_fsdp2():
    """
    monkeypatch for fixing the eval loop for fsdp2 with torch.compile

    Rewrites the source of `Trainer.evaluation_loop`, replacing
    `ORIGINAL_TRAINER_CODE` with `PATCHED_TRAINER_CODE`, compiles the result in
    this module's globals as `_fixed_evaluation_loop`, and installs it on
    `Trainer`. Returns without patching when the source cannot be retrieved or
    does not contain the snippet to replace.
    """

    try:
        evaluation_loop = get_evaluation_loop_code()
    except OSError:
        # inspect.getsource raises OSError when the source is unavailable.
        return
    # Keep the unmodified source text on the class before any rewriting.
    Trainer._original_evaluation_loop = (  # pylint: disable=protected-access
        evaluation_loop
    )
    evaluation_loop, _ = detab_code(evaluation_loop)
    if ORIGINAL_TRAINER_CODE not in evaluation_loop:
        return

    evaluation_loop = evaluation_loop.replace(
        ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
    )
    # Rename only the first occurrence so the definition itself gets the new name.
    evaluation_loop = evaluation_loop.replace(
        "def evaluation_loop(",
        "def _fixed_evaluation_loop(",
        1,
    )

    # load imports necessary
    import transformers.trainer

    # Any public name of transformers.trainer whose text occurs in the rewritten
    # source is pulled into this module's globals so the exec'd function resolves it.
    items_to_import = [
        item for item in dir(transformers.trainer) if item in evaluation_loop
    ]

    # An empty list would produce `from ... import ()`, which is a SyntaxError.
    if items_to_import:
        exec(  # pylint: disable=exec-used  # nosec B102
            "from transformers.trainer import ("
            + ", ".join(items_to_import)
            + ")",
            globals(),
        )
    exec(evaluation_loop, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching evaluation_loop for fsdp2 with torch.compile")
    Trainer.evaluation_loop = (  # pylint: disable=protected-access
        _fixed_evaluation_loop  # pylint: disable=undefined-variable  # noqa: F821
    )
|
||||
@@ -81,6 +81,11 @@ def setup_model_and_tokenizer(
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
if any(
|
||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||
for param in cfg.unfrozen_parameters
|
||||
):
|
||||
model.enable_input_require_grads()
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
model: Any | None = None
|
||||
padding: bool | str | PaddingStrategy = True
|
||||
max_length: int | None = None
|
||||
pad_to_multiple_of: int | None = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_degree: int = 1
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
# Get local (start, end) for sequence parallelism slicing
|
||||
total_seq_len = batch["input_ids"].shape[1]
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Update params for ring attention calculation
|
||||
update_ring_attn_params(batch=batch)
|
||||
# Update params for varlen ring attention calculation
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
|
||||
# Slice batch for sequence parallel processing
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
batch[key] = batch[key][:, start:end]
|
||||
for key in batch:
|
||||
if batch[key].size(1) == total_seq_len:
|
||||
if self.ring_attn_func in [
|
||||
RingAttnFunc.VARLEN_LLAMA3,
|
||||
RingAttnFunc.BATCH_RING,
|
||||
]:
|
||||
batch[key] = (
|
||||
batch[key]
|
||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||
.contiguous()
|
||||
)
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||
|
||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||
selected_chunks = [
|
||||
chunks[self.local_rank],
|
||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||
]
|
||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||
# TODO(djsaunde): This doesn't seem to work as expected
|
||||
# Split into striped data and stack
|
||||
tensor = torch.stack(
|
||||
batch[key].split(self.local_world_size, dim=1),
|
||||
dim=1,
|
||||
).transpose(1, 2)
|
||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -236,6 +236,18 @@ def normalize_config(cfg):
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
if cfg.quantization:
|
||||
if cfg.quantization.backend in ["bnb"]:
|
||||
if cfg.quantization.bits == 8:
|
||||
cfg.load_in_8bit = True
|
||||
elif cfg.quantization.bits == 4:
|
||||
cfg.load_in_4bit = True
|
||||
|
||||
if cfg.quantization.backend == "gptq":
|
||||
cfg.gptq = True
|
||||
elif cfg.quantization.backend == "hqq":
|
||||
cfg.hqq = True
|
||||
|
||||
|
||||
def normalize_cfg_datasets(cfg):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
)
|
||||
|
||||
iter_ds = load_dataset(
|
||||
path, streaming=True, split=split, name=name, data_files=data_files
|
||||
)
|
||||
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
|
||||
# other ranks, we just need to present a fake dataset
|
||||
if (
|
||||
cfg.accelerator_config
|
||||
and cfg.accelerator_config.dispatch_batches
|
||||
and not is_local_main_process()
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
|
||||
f.write("text\n")
|
||||
f.write("lorem ipsum dolor sit amet\n")
|
||||
# rewind the file pointer to the beginning so we can read it again
|
||||
f.seek(0)
|
||||
iter_ds = load_dataset(
|
||||
"csv", data_files=f.name, split="train", streaming=True
|
||||
)
|
||||
else:
|
||||
if is_local_main_process():
|
||||
iter_ds = load_dataset(
|
||||
path, streaming=True, split=split, name=name, data_files=data_files
|
||||
)
|
||||
|
||||
if skip:
|
||||
LOG.info(f"Skipping {skip} samples from the dataset")
|
||||
iter_ds = iter_ds.skip(skip)
|
||||
@@ -332,16 +351,23 @@ def load_tokenized_prepared_datasets(
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
if isinstance(dataset, IterableDataset):
|
||||
num_workers = cfg.dataset_processes
|
||||
|
||||
def gen_from_iter_ds(_ds, _=None):
|
||||
yield from _ds
|
||||
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
|
||||
"""Generator function to correctly splice the dataset for each worker"""
|
||||
for i, item in enumerate(_ds):
|
||||
if i % num_workers[0] == worker_id[0]:
|
||||
yield item
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(gen_from_iter_ds, dataset),
|
||||
features=dataset.features,
|
||||
num_proc=cfg.dataset_processes,
|
||||
num_proc=num_workers,
|
||||
split=split,
|
||||
gen_kwargs={"_": list(range(cfg.dataset_processes))},
|
||||
gen_kwargs={
|
||||
"worker_id": list(range(num_workers)),
|
||||
"num_workers": [num_workers] * num_workers,
|
||||
},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
else:
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
module to freeze/unfreeze parameters by name
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def freeze_layers_except(model, regex_patterns):
|
||||
@@ -184,7 +185,7 @@ class LayerNamePattern:
|
||||
"""
|
||||
self.raw_pattern = pattern
|
||||
name_pattern, self.range = self._parse_pattern(pattern)
|
||||
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
||||
self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
|
||||
|
||||
def match(self, name: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -36,6 +36,7 @@ from transformers import (
|
||||
BitsAndBytesConfig,
|
||||
Gemma3ForConditionalGeneration,
|
||||
GPTQConfig,
|
||||
HqqConfig,
|
||||
Llama4ForConditionalGeneration,
|
||||
LlavaForConditionalGeneration,
|
||||
Mistral3ForConditionalGeneration,
|
||||
@@ -542,6 +543,17 @@ class ModelLoader:
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||
|
||||
patch_accelerate_fsdp_utils()
|
||||
|
||||
if self.cfg.flex_attention:
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
patch_flex_make_mask()
|
||||
|
||||
# patch gemma3 conditional generation forward before loading plugins
|
||||
# as it could be overridden by plugins
|
||||
if self.cfg.model_config_type == "llama4":
|
||||
@@ -644,6 +656,7 @@ class ModelLoader:
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
ring_attn_func=self.cfg.ring_attn_func,
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
@@ -821,6 +834,13 @@ class ModelLoader:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
def set_quantization_config(self) -> None:
|
||||
if (
|
||||
(not self.cfg.quantization)
|
||||
and (not self.cfg.load_in_8bit)
|
||||
and (not self.cfg.load_in_4bit)
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
return
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
@@ -842,21 +862,21 @@ class ModelLoader:
|
||||
and hasattr(self.model_config, "quantization_config")
|
||||
and self.model_config.quantization_config["quant_method"]
|
||||
in ["gptq", "awq", "bitsandbytes"]
|
||||
and not self.cfg.hqq
|
||||
):
|
||||
if self.model_config.quantization_config["quant_method"] == "gptq":
|
||||
self.model_kwargs["quantization_config"] = GPTQConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.model_config.quantization_config["quant_method"] == "awq":
|
||||
self.model_kwargs["quantization_config"] = AwqConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif (
|
||||
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
|
||||
):
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
quant_config_class_dict = {
|
||||
"gptq": GPTQConfig,
|
||||
"awq": AwqConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
}
|
||||
|
||||
quant_config_class = quant_config_class_dict[
|
||||
self.model_config.quantization_config["quant_method"]
|
||||
]
|
||||
self.model_kwargs["quantization_config"] = quant_config_class(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
@@ -874,8 +894,8 @@ class ModelLoader:
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
|
||||
if self.cfg.bnb_config_kwargs:
|
||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
|
||||
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
||||
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
@@ -891,6 +911,13 @@ class ModelLoader:
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
if self.cfg.hqq:
|
||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||
|
||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||
**get_hqq_quant_config_kwargs(self.cfg)
|
||||
)
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||
self.model_kwargs.pop("load_in_8bit", None)
|
||||
@@ -905,13 +932,6 @@ class ModelLoader:
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
)
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
patch_flex_wrapper,
|
||||
)
|
||||
|
||||
patch_flex_wrapper()
|
||||
patch_flex_make_mask()
|
||||
|
||||
elif self.cfg.flash_attention:
|
||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||
@@ -1031,6 +1051,12 @@ class ModelLoader:
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
if self.cfg.hqq and torch.cuda.device_count() < 2:
|
||||
# for some reason on single gpu, we need to set device_map to auto/cuda
|
||||
# otherwise you run into tensors on two devices error during training
|
||||
# Doesn't affect multi-gpu tho
|
||||
|
||||
self.model_kwargs["device_map"] = "auto"
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
@@ -1115,7 +1141,7 @@ class ModelLoader:
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def ajust_model_config(self) -> None:
|
||||
def adjust_model_config(self) -> None:
|
||||
if (
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "max_position_embeddings")
|
||||
@@ -1185,7 +1211,7 @@ class ModelLoader:
|
||||
if (
|
||||
not skip_prepare_model_for_kbit_training
|
||||
and self.cfg.adapter in ["lora", "qlora"]
|
||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
|
||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
self.model = prepare_model_for_kbit_training(
|
||||
@@ -1275,7 +1301,7 @@ class ModelLoader:
|
||||
else:
|
||||
self.model.tie_weights()
|
||||
|
||||
self.ajust_model_config()
|
||||
self.adjust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
@@ -1455,7 +1481,16 @@ def load_llama_adapter(model, cfg):
|
||||
|
||||
|
||||
def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||
from hqq.core.peft import HQQLinearLoRA
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
cls = (
|
||||
bnb.nn.Linear4bit,
|
||||
bnb.nn.Linear8bitLt,
|
||||
torch.nn.Linear,
|
||||
HQQLinear,
|
||||
HQQLinearLoRA,
|
||||
)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
|
||||
@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
|
||||
self.max_lr = max_lr
|
||||
self.total_steps = total_steps
|
||||
self.num_warmup_steps = num_warmup_steps
|
||||
self.last_step = last_step - 1
|
||||
self.last_step = max(last_step - 1, 0)
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
|
||||
@@ -225,6 +225,7 @@ class AxolotlInputConfig(
|
||||
sdp_attention: bool | None = None
|
||||
s2_attention: bool | None = None
|
||||
flex_attention: bool | None = None
|
||||
flex_attn_compile_kwargs: dict[str, Any] | None = None
|
||||
flash_attention: bool | None = None
|
||||
flash_attn_cross_entropy: bool | None = None
|
||||
flash_attn_rms_norm: bool | None = None
|
||||
@@ -258,6 +259,7 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: str | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -658,6 +660,7 @@ class AxolotlInputConfig(
|
||||
data.get("val_set_size") == 0
|
||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||
and not data.get("test_datasets")
|
||||
and data.get("eval_strategy") != "no"
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||
@@ -1146,7 +1149,7 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@field_validator("sequence_parallel_degree", mode="after")
|
||||
@classmethod
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
if not value:
|
||||
@@ -1158,9 +1161,12 @@ class AxolotlInputConfig(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if not info.data["micro_batch_size"] == 1:
|
||||
if (
|
||||
info.data.get("sample_packing")
|
||||
and not info.data["micro_batch_size"] == 1
|
||||
):
|
||||
raise ValueError(
|
||||
"micro_batch_size must be set to 1 "
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
@@ -1187,6 +1193,34 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
    """Validate `ring_attn_func`, or pick a default when it is unset.

    Only applies when sequence parallelism is active
    (`sequence_parallel_degree` > 1); otherwise `value` is returned untouched.

    Args:
        value: The configured `ring_attn_func` (or None).
        info: Pydantic validation info; `info.data` is read for
            `sequence_parallel_degree` and `sample_packing`.

    Returns:
        `value` unchanged when sequence parallelism is off, otherwise a
        `RingAttnFunc` member.

    Raises:
        ValueError: If `value` is set but is not a valid `RingAttnFunc`.
    """
    # `sequence_parallel_degree` is declared `int | None`, so the key can be
    # present with a None value; `None > 1` would raise TypeError. Treat a
    # missing / None / 0 degree as 1 (no sequence parallelism).
    if (info.data.get("sequence_parallel_degree") or 1) <= 1:
        return value

    # Imported lazily, as in the rest of this validator's module, to avoid
    # pulling in the monkeypatch package at schema import time.
    from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

    if value is None:
        # Default ring attention function selection
        if info.data.get("sample_packing"):
            return RingAttnFunc.VARLEN_LLAMA3
        return RingAttnFunc.BATCH_RING

    # Set the ring attention function if passed in config
    valid_funcs = list(RingAttnFunc)
    if value not in valid_funcs:
        raise ValueError(f"ring_attn_func: {value} must be one of {valid_funcs}")

    return RingAttnFunc(value)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
@@ -1276,11 +1310,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
is_fsdp2 = (
|
||||
data.get("fsdp_config") is not None
|
||||
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
|
||||
)
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Pydantic models for PEFT-related configuration"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from axolotl.utils.schemas.quant import QuantizationConfig
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
|
||||
class LoraConfig(BaseModel):
|
||||
"""Peft / LoRA configuration subset"""
|
||||
|
||||
load_in_8bit: bool | None = Field(default=False)
|
||||
load_in_4bit: bool | None = Field(default=False)
|
||||
quantization: QuantizationConfig | None = None
|
||||
load_in_4bit: bool | None = None # for internal use
|
||||
load_in_8bit: bool | None = None # for internal use
|
||||
hqq: bool | None = None # for internal use
|
||||
gptq: bool | None = None # for internal use
|
||||
|
||||
adapter: str | None = None
|
||||
lora_model_dir: str | None = None
|
||||
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
lora_on_cpu: bool | None = None
|
||||
gptq: bool | None = None
|
||||
bnb_config_kwargs: dict[str, Any] | None = None
|
||||
|
||||
loraplus_lr_ratio: float | None = Field(
|
||||
default=None,
|
||||
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
|
||||
if (
|
||||
not data.get("adapter")
|
||||
and not data.get("inference")
|
||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||
and (data.get("quantization"))
|
||||
):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
"Quantization is not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off Quantization."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -86,25 +87,26 @@ class LoraConfig(BaseModel):
|
||||
def validate_qlora(self):
|
||||
if self.adapter == "qlora":
|
||||
if self.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if self.load_in_8bit:
|
||||
if self.quantization.bits == 8 or self.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if self.gptq:
|
||||
raise ValueError("Can't merge qlora if gptq")
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't merge qlora if using gptq")
|
||||
|
||||
if self.load_in_4bit:
|
||||
if self.quantization.bits == 4 or self.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if self.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
if self.quantization:
|
||||
if self.quantization.bits == 8 or self.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if self.gptq:
|
||||
raise ValueError("Can't load qlora if gptq")
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't load qlora if using gptq")
|
||||
|
||||
if not self.quantization.bits == 4 or self.load_in_4bit:
|
||||
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||
|
||||
if not self.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
return self
|
||||
|
||||
@field_validator("loraplus_lr_embedding")
|
||||
@@ -121,6 +123,24 @@ class LoraConfig(BaseModel):
|
||||
data["lora_dropout"] = 0.0
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
    """Check that an HQQ quantization request is complete and unambiguous.

    Does nothing unless `data["quantization"]["backend"]` is "hqq".

    Raises:
        ValueError: If `hqq_config` is missing under `quantization`, or if
            `load_in_4bit` / `load_in_8bit` is set alongside HQQ.
    """
    quantization = data.get("quantization")

    # Guard clause: nothing to validate for non-HQQ configurations.
    if not (quantization and quantization.get("backend") == "hqq"):
        return data

    if not quantization.get("hqq_config"):
        raise ValueError(
            "If using HQQ, must set `hqq_config` under `quantization`"
        )

    legacy_bnb_flag = data.get("load_in_4bit") or data.get("load_in_8bit")
    if legacy_bnb_flag:
        raise ValueError(
            "If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
        )

    return data
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
93
src/axolotl/utils/schemas/quant.py
Normal file
93
src/axolotl/utils/schemas/quant.py
Normal file
@@ -0,0 +1,93 @@
|
||||
""" "
|
||||
Takes care of quantization configuration
|
||||
"""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from annotated_types import MinLen
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class HQQConfig(BaseModel):
    """HQQ configuration subset"""

    # Bit width for this HQQ entry; None means "not set here".
    nbits: Literal[8, 4, 3, 2, 1] | None = Field(
        None,
        json_schema_extra={
            "description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
        },
    )

    # Quantization group size; defaults to 64.
    group_size: int = Field(64)

    # One module name, a list of names, or None for the whole model.
    target_modules: list[str] | str | None = Field(
        None,
        json_schema_extra={
            "description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
        },
    )
|
||||
|
||||
|
||||
class QuantizationConfig(BaseModel):
    """Overall quantization configuration subset.

    Attributes are documented inline below. `hqq_config` may be a single
    `HQQConfig` or a non-empty list of them.
    """

    # Intended as the base for a future refactoring of all quantization configs.
    backend: Literal["bnb", "hqq", "gptq"] | None = None
    # Global bit width.
    bits: Literal[8, 4, 3, 2, 1] | None = None
    # Extra keyword arguments for the bitsandbytes config.
    bnb_config_kwargs: dict[str, Any] | None = None
    # HQQ settings: one config, or one config per group of target modules.
    hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None

    @model_validator(mode="before")
    @classmethod
    def check_hqq_config(cls, data):
        """Validate the raw `hqq_config` input before field parsing.

        Raises:
            ValueError: If the backend is "hqq" but no `hqq_config` is given,
                or if a list of several HQQ configs contains an entry
                without `target_modules`.
        """
        # "before" validators receive the raw input; only mappings can be
        # inspected with `.get`, anything else is left for pydantic to handle.
        if not isinstance(data, dict):
            return data

        hqq_config = data.get("hqq_config")

        if data.get("backend") == "hqq" and not hqq_config:
            raise ValueError("If using HQQ, must set `hqq_config` under `quantization`")

        # Only a *list* of configs needs the per-entry check. A single config
        # given as a mapping must not be iterated: `len()` would count its keys
        # and iteration would yield key strings rather than config entries.
        if isinstance(hqq_config, (list, tuple)) and len(hqq_config) > 1:
            for entry in hqq_config:
                if isinstance(entry, dict):
                    target_modules = entry.get("target_modules")
                else:
                    # Entry may already be an HQQConfig-like object.
                    target_modules = getattr(entry, "target_modules", None)

                if target_modules is None:
                    raise ValueError(
                        "For list of hqq configs, `target_modules` must be specified for each"
                    )

        return data
|
||||
|
||||
|
||||
def get_hqq_quant_config_kwargs(cfg):
    """Build the HQQ quantization kwargs from `cfg.quantization`.

    Args:
        cfg: Config object exposing `quantization.bits` and
            `quantization.hqq_config` (a single HQQ config or a list of
            them, each with `nbits`, `group_size` and `target_modules`).

    Returns:
        `{"nbits": ..., "group_size": ...}` when there is exactly one HQQ
        config without `target_modules` (whole-model quantization);
        otherwise `{"dynamic_config": {module: {"nbits": ..., "group_size":
        ...}}}` with one entry per target module. A config's own `nbits`
        takes precedence over `cfg.quantization.bits`.
    """
    quantization = cfg.quantization

    # Normalize to a local list; the caller's config object is left untouched.
    hqq_configs = quantization.hqq_config
    if not isinstance(hqq_configs, list):
        hqq_configs = [hqq_configs]

    def _resolve_nbits(hqq_config):
        """Per-config `nbits`, falling back to the global `bits`."""
        if hqq_config.nbits is not None:
            return hqq_config.nbits
        return quantization.bits

    # If no target module is specified, then target the whole model
    if len(hqq_configs) == 1 and hqq_configs[0].target_modules is None:
        whole_model = hqq_configs[0]
        return {
            "nbits": _resolve_nbits(whole_model),
            "group_size": whole_model.group_size,
        }

    dynamic_config = {}
    for hqq_config in hqq_configs:
        target_modules = hqq_config.target_modules
        # A single module name is accepted as shorthand for a one-item list.
        if not isinstance(target_modules, list):
            target_modules = [target_modules]

        nbits = _resolve_nbits(hqq_config)
        for module in target_modules:
            dynamic_config[module] = {
                "nbits": nbits,
                "group_size": hqq_config.group_size,
            }

    return {"dynamic_config": dynamic_config}
|
||||
@@ -17,6 +17,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
@@ -235,7 +236,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if cfg.model_config_type in ["mamba", "gemma3"]:
|
||||
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
||||
if drop_attn_mask:
|
||||
LOG.info("dropping attention_mask column")
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
@@ -625,6 +627,12 @@ def setup_trainer(
|
||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||
on the provided parameters.
|
||||
"""
|
||||
if (
|
||||
cfg.torch_compile
|
||||
and cfg.fsdp_config
|
||||
and str(cfg.fsdp_config.fsdp_version) == "2"
|
||||
):
|
||||
patch_evaluation_loop_for_fsdp2()
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
trainer_builder.model_ref = model_ref
|
||||
|
||||
@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
|
||||
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
    """Session-wide autouse fixture that downloads the evolkit KD sample dataset."""
    dataset_repo = "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample"
    snapshot_download_w_retry(dataset_repo, repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_deepseek_model_fixture():
|
||||
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
||||
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
    """Session-wide autouse fixture that downloads the Llama 3.3 70B tokenizer files.

    Only files matching `*token*` plus `config.json` are requested.
    """
    tokenizer_only_patterns = ["*token*", "config.json"]
    snapshot_download_w_retry(
        "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
        repo_type="model",
        allow_patterns=tokenizer_only_patterns,
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_llama_1b_model_fixture():
|
||||
# download the tokenizer only
|
||||
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
    """Session-wide autouse fixture that downloads the Llama 3.2 1B model repo."""
    model_repo = "osllmai-community/Llama-3.2-1B"
    snapshot_download_w_retry(model_repo, repo_type="model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
|
||||
0
tests/e2e/multigpu/patched/__init__.py
Normal file
0
tests/e2e/multigpu/patched/__init__.py
Normal file
@@ -3,13 +3,14 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_tensorboard
|
||||
from ...utils import check_tensorboard
|
||||
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
|
||||
class TestSequenceParallelism:
|
||||
"""Test case for training with sequence parallelism enabled"""
|
||||
|
||||
def test_sequence_parallel_training(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
def _run_sequence_parallel_test(
|
||||
self,
|
||||
temp_dir,
|
||||
sample_packing=True,
|
||||
micro_batch_size=1,
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sample_packing": sample_packing,
|
||||
"eval_sample_packing": sample_packing,
|
||||
"pad_to_sequence_len": pad_to_sequence_len,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 8,
|
||||
"micro_batch_size": 1,
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
    "sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
    [
        (True, 1, True, None),  # defaults to varlen_llama3 ring_attn_func
        (False, 2, True, None),  # defaults to batch_ring ring_attn_func
        (False, 2, True, "batch_zigzag"),
        # TODO: (False, 2, False, None) -- not yet working
    ],
    # Each id describes the parameters of the case at the same position above;
    # every enabled case runs with pad_to_sequence_len=True.
    ids=[
        "sample_packing, varlen_llama3 ring_attn_func",
        "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
        "no sample_packing, pad_to_sequence_len, batch_zigzag ring_attn_func",
        # TODO: "no sample_packing, no pad_to_sequence_len" -- not yet working
    ],
)
def test_sequence_parallel_training(
    self,
    temp_dir,
    sample_packing,
    micro_batch_size,
    pad_to_sequence_len,
    ring_attn_func,
):
    """Test sequence parallel training with different configurations.

    Thin parametrized wrapper: forwards every parameter to
    `_run_sequence_parallel_test`.
    """
    self._run_sequence_parallel_test(
        temp_dir,
        sample_packing=sample_packing,
        micro_batch_size=micro_batch_size,
        pad_to_sequence_len=pad_to_sequence_len,
        ring_attn_func=ring_attn_func,
    )
|
||||
@@ -0,0 +1,2 @@
|
||||
# Tests under this directory should get run "solo" on their own as they
|
||||
# seem to cause issues when run in the same batch as other tests.
|
||||
|
||||
@@ -49,18 +49,20 @@ class TestPackedFlex:
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"max_steps": 2,
|
||||
"use_tensorboard": True,
|
||||
"save_strategy": "no",
|
||||
}
|
||||
|
||||
@@ -177,6 +177,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
@@ -264,6 +265,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process_id = start_vllm(
|
||||
cfg.base_model,
|
||||
|
||||
@@ -30,8 +30,10 @@ class TestMultiGPUEval:
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
@@ -99,8 +101,10 @@ class TestMultiGPUEval:
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"strict": False,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
|
||||
@@ -171,7 +171,10 @@ class TestMultiGPULlama:
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
@@ -249,7 +252,10 @@ class TestMultiGPULlama:
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
@@ -548,7 +554,10 @@ class TestMultiGPULlama:
|
||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||
"adapter": "qlora",
|
||||
"mean_resizing_embeddings": True,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
@@ -621,12 +630,6 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
# TODO: remove skip once deepspeed regression is fixed
|
||||
# see https://github.com/huggingface/transformers/pull/37324
|
||||
@pytest.mark.skipif(
|
||||
transformers_version_eq("4.51.0"),
|
||||
reason="zero3 is not supported with transformers==4.51.0",
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
@@ -654,7 +657,10 @@ class TestMultiGPULlama:
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
@@ -728,7 +734,10 @@ class TestMultiGPULlama:
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
@@ -802,7 +811,10 @@ class TestMultiGPULlama:
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
else:
|
||||
adapter = {}
|
||||
|
||||
@@ -28,7 +28,10 @@ class TestMultiGPUQwen2:
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": base_model,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"sequence_len": 2048,
|
||||
|
||||
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
||||
def test_geglu_model_integration():
|
||||
"""Test GeGLU activation with Gemma model."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
|
||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
|
||||
)
|
||||
peft_config = get_peft_config(
|
||||
{
|
||||
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
|
||||
"""Test LoRA kernel patches across different model architectures."""
|
||||
# Load model with appropriate dtype
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
|
||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
|
||||
)
|
||||
|
||||
# Apply LoRA configuration
|
||||
|
||||
@@ -32,7 +32,10 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
|
||||
@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
},
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"gptq": True,
|
||||
|
||||
@@ -33,7 +33,10 @@ class TestMixtral(unittest.TestCase):
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
|
||||
@@ -46,8 +46,9 @@ class TestResumeLlama:
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
|
||||
@@ -73,7 +73,10 @@ class TestRingAttention:
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
RingAttnFunc,
|
||||
register_ring_attn,
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
@@ -82,7 +85,11 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
|
||||
@@ -34,7 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
|
||||
@@ -35,7 +35,10 @@ class TestMixtral(unittest.TestCase):
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
"lora_r": 4,
|
||||
"lora_alpha": 8,
|
||||
@@ -91,7 +94,10 @@ class TestMixtral(unittest.TestCase):
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 1024,
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
"lora_r": 4,
|
||||
"lora_alpha": 8,
|
||||
|
||||
@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
|
||||
141
tests/e2e/test_quantization.py
Normal file
141
tests/e2e/test_quantization.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
E2E tests for training with quantized model
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestHQQ(unittest.TestCase):
    """
    Test cases for training of HQQ-quantized llama models
    """

    def _run_hqq_training(self, temp_dir, *, nbits, adapter):
        """
        Run a short HQQ-quantized adapter training job and check the logged loss.

        Both public tests share this body; they differ only in the HQQ bit
        width and the adapter type, so keeping the config in one place stops
        the two variants from drifting apart.

        Args:
            temp_dir: directory used as the training ``output_dir``; the
                tensorboard logs are read back from ``temp_dir + "/runs"``.
            nbits: value placed in the single ``hqq_config`` entry.
            adapter: value of the ``adapter`` config key ("lora" or "qlora"
                in the tests below).
        """
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "sample_packing": True,
                "flash_attention": True,
                "use_hqq": True,
                "hqq_config": [
                    {
                        "nbits": nbits,
                        "group_size": 64,
                    }
                ],
                "adapter": adapter,
                "lora_r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.0,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "vicgalle/alpaca-gpt4",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "use_tensorboard": True,
            }
        )
        # Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)

        # NOTE(review): presumably asserts the logged train loss stays below
        # 2.0 -- confirm against check_tensorboard in tests/e2e/utils.
        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
        )

    @with_temp_dir
    def test_hqq_lora(self, temp_dir):
        """Train a LoRA adapter on top of an 8-bit HQQ-quantized base model."""
        self._run_hqq_training(temp_dir, nbits=8, adapter="lora")

    @with_temp_dir
    def test_hqq_qlora(self, temp_dir):
        """Train a QLoRA adapter on top of a 4-bit HQQ-quantized base model."""
        self._run_hqq_training(temp_dir, nbits=4, adapter="qlora")
|
||||
@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
# "load_in_4bit": True
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": "",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": None,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
"quantization": {
|
||||
"bits": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||
with pytest.raises(ValueError, match=r".*bits <= 4*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
match=r"Quantization is not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
)
|
||||
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"bits": 8,
|
||||
},
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -21,8 +21,10 @@ class TestModelsUtils:
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"model_type": "LlamaForCausalLM",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"load_in_8bit": True,
|
||||
"load_in_4bit": False,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
"adapter": "lora",
|
||||
"flash_attention": False,
|
||||
"sample_packing": True,
|
||||
|
||||
Reference in New Issue
Block a user