Compare commits

...

12 Commits

Author SHA1 Message Date
Wing Lian
a0670abc94 add output for train loss in assertian err 2025-04-18 08:11:11 -07:00
Wing Lian
08f287b57f swap llama tests for 7m param model 2025-04-17 09:52:35 -07:00
Wing Lian
b4c7d9c29d fix perplexity scores 2025-04-17 07:58:53 -07:00
Wing Lian
d2637fb01d first pass at modifying tests to use llama-7m 2025-04-16 21:14:04 -07:00
NanoCode012
9da730d6a4 fix(doc): cut cross entropy installation instructions broken in qmd (#2532) 2025-04-16 15:02:51 -07:00
NanoCode012
32637fad00 fix: preprocess yielding whole dataset to each worker (#2503) [skip ci] 2025-04-16 15:02:35 -07:00
Dan Saunders
f776f889a1 adding codecov reporting (#2372) [skip ci]
* adding codecov reporting

* update codecov-action to v5

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-16 15:02:17 -07:00
Wing Lian
69eda209a6 re-enable DS zero3 ci with updated transformers (#2533) 2025-04-16 14:48:40 -07:00
Dan Saunders
b8c633aa97 batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* removing pad_to_sequence_len=False for now

* fix

* updating docs to include batch SP

* review comments

* fixes for batch API funcs, simplify

* fixes

* fix

* updates

* add batch_zigzag smoke test
2025-04-16 13:50:48 -04:00
NanoCode012
682a9cf79b Fix: add delinearization and make qlora work with fsdp2 (#2515)
* fixes for delinearization, and make qlora work with fsdp2

* Add back mistakenly removed lm_eval

* typo [skip ci]

* patch evals for torch.compile + fsdp2

* also check torch_compile w fsdp2

* lots of fixes for flex attn with llama4

* fix patch check and patch llama4 too

* attempt to make the patches stick

* use transformers 4.51.2

* update configs and README for llama4

* remove torch.compile for CI test

* cleanup any existing singletons

* set singleton cache to None instead of deleting

* use importlib reload with monkeypatch

* don't worry about transformers version, mark inputs with grads, fix regex

* make sure embeds aren't on cpu

* logging and mem improvements

* vllm version and add to docker, make sure to save processor on conversion

* fix ambiguous tensor bool check

* fix vllm to not use v1, upgrade hf transformers

* fix tests

* make flex_attn_compile_kwargs configurable, since this depends on model params

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-04-15 23:31:39 -07:00
NanoCode012
271b24cccc feat: update cce to latest (#2521) 2025-04-15 22:17:10 -07:00
Wing Lian
198d775d6d make sure the all of the model is on the same device, so this test will pass on multigpu (#2524) [skip ci] 2025-04-15 22:15:42 -07:00
62 changed files with 1246 additions and 162 deletions

14
.coveragerc Normal file
View File

@@ -0,0 +1,14 @@
[run]
source = axolotl
omit =
    */tests/*
    setup.py

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    raise NotImplementedError
    if __name__ == .__main__.:
    pass
    raise ImportError
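The new `.coveragerc` is picked up automatically by coverage.py. A minimal sketch of collecting coverage programmatically against this config (assumes coverage.py is installed and is run from the repo root; the measured import is just a placeholder):

```python
# Minimal sketch: programmatic coverage collection using the .coveragerc above.
# Assumes coverage.py is installed; the measured code below is a placeholder.
import coverage

cov = coverage.Coverage()  # reads .coveragerc: source=axolotl, omit patterns, exclude_lines
cov.start()
import axolotl  # noqa: F401  # exercise whatever code you want measured here
cov.stop()
cov.save()
cov.report()  # terminal report honoring the [report] exclusions
```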

View File

@@ -29,7 +29,7 @@ jobs:
      cuda_version: 12.4.1
      python_version: "3.11"
      pytorch: 2.6.0
-     axolotl_extras:
+     axolotl_extras: vllm
      is_latest: true
      runs-on: axolotl-gpu-runner
      steps:

View File

@@ -102,9 +102,16 @@ jobs:
      - name: Run tests
        run: |
-         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
-         pytest -v tests/patched/
-         pytest -v tests/cli/
+         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
+         pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
+         pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
+     - name: Upload coverage to Codecov
+       uses: codecov/codecov-action@v5
+       with:
+         files: ./coverage.xml
+         flags: unittests,pytorch-${{ matrix.pytorch_version }}
+         fail_ci_if_error: false
      - name: cleanup pip cache
        run: |

View File

@@ -9,6 +9,7 @@
  <p align="center">
    <img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
    <img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
+   <a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
    <a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
    <br/>
    <a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>

View File

@@ -3,10 +3,59 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/ # Run unit tests with initial coverage report
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure pytest -v --durations=10 -n8 \
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched --ignore=tests/e2e/ \
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ --ignore=tests/patched/ \
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ --ignore=tests/cli \
pytest -v --durations=10 /workspace/axolotl/tests/cli /workspace/axolotl/tests/ \
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/ --cov=axolotl \
--cov-report=xml:coverage.xml
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/patched/lora_kernels \
--cov=axolotl \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest -v --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \
--cov-append
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/solo/ \
--cov=axolotl \
--cov-append
# Run integration tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/integrations/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 /workspace/axolotl/tests/cli \
--cov=axolotl \
--cov-append
# Run remaining e2e tests with coverage append and final report
pytest -v --durations=10 \
--ignore=tests/e2e/solo/ \
--ignore=tests/e2e/patched/ \
--ignore=tests/e2e/multigpu/ \
--ignore=tests/e2e/integrations/ \
--ignore=tests/cli \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:coverage.xml
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -4,3 +4,22 @@ set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
+
+# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
+pytest -v -n2 \
+    --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
+    /workspace/axolotl/tests/e2e/multigpu/ \
+    --cov=axolotl \
+    --cov-report=xml:multigpu-coverage.xml
+
+pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
+    --cov=axolotl \
+    --cov-append \
+    --cov-report=xml:multigpu-coverage.xml
+
+# Upload coverage to Codecov
+if [ -f multigpu-coverage.xml ]; then
+    codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
+else
+    echo "Coverage file not found. Coverage report may have failed."
+fi

51
codecov.yml Normal file
View File

@@ -0,0 +1,51 @@
codecov:
  require_ci_to_pass: yes

coverage:
  precision: 2
  round: down
  range: "70...100"
  status:
    project:
      default:
        # basic
        target: auto
        threshold: 0%
        base: auto
        # advanced
        branches: null
        if_no_uploads: error
        if_not_found: success
        if_ci_failed: error
        only_pulls: false
        flags: null
        paths: null
    patch:
      default:
        # basic
        target: auto
        threshold: 0%
        base: auto
        # advanced
        branches: null
        if_no_uploads: error
        if_not_found: success
        if_ci_failed: error
        only_pulls: false
        flags: null
        paths: null

parsers:
  gcov:
    branch_detection:
      conditional: yes
      loop: yes
      method: no
      macro: no

comment:
  layout: "reach,diff,flags,files,footer"
  behavior: default
  require_changes: no
  require_base: no
  require_head: yes

View File

@@ -693,6 +693,9 @@ sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
+# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
+# in the sample packing case, and "batch_ring" in the non-sample packing case.
+ring_attn_func:

# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
sequence_parallel_degree: 4  # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
+# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
+# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
+ring_attn_func:
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
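A minimal sketch of how a config value for the new `ring_attn_func` option maps onto the `RingAttnFunc` enum introduced later in this changeset (the config value used here is illustrative):

```python
# Illustrative only: shows how a string from the YAML config resolves to the
# RingAttnFunc enum added in axolotl.monkeypatch.attention.ring_attn.patch.
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

cfg_value = "batch_zigzag"                # example value, as it would appear in the config
ring_attn_func = RingAttnFunc(cfg_value)  # str-backed Enum, so lookup by value works
assert ring_attn_func is RingAttnFunc.BATCH_ZIGZAG

# With sample packing enabled, the varlen implementation is the expected default instead:
packed_default = RingAttnFunc.VARLEN_LLAMA3
```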

View File

@@ -1,16 +1,28 @@
# Llama 4 by Meta AI

+## Flash Attention vs Flex Attention
+
+While Flash Attention support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
+
## Available Examples

### Llama 4 Scout 17Bx16Experts (109B)

-- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
-- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
-- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
-
-Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
+Flex Attention
+- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
+- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
+
+[//]: # (Flash Attention &#40;Do not use&#41;)
+[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1]&#40;./scout-vision-qlora-fsdp.yaml&#41;)
+[//]: # (- [Text Single GPU &#40;H100&#41; QLoRA]&#40;./scout-qlora-single-h100.yaml&#41;)
+[//]: # (- [Text Multi GPU QLoRA w/ FSDP1]&#40;./scout-qlora-fsdp1.yaml&#41;)
+
+Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
+
+Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @ 280 tps/gpu, [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)

### Llama 4 Maverick 17Bx128Experts (400B)

-- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
-
-Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
+Coming Soon

View File

@@ -0,0 +1,86 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head
# - embed_tokens
chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>

View File

@@ -0,0 +1,85 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$ # optionally train the moe experts
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head # needed if modifying vocabulary
# - embed_tokens
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
torch_compile: true
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false
logging_steps: 1
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>

View File

@@ -0,0 +1,89 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
sequence_len: 4096
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
- vision_adapter.mlp.fc1
- vision_adapter.mlp.fc2
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>

View File

@@ -1,6 +1,6 @@
+pre-commit
black
mypy
-pre-commit
types-requests
quartodoc
jupyter

View File

@@ -1,5 +1,7 @@
+codecov
pytest
-pytest-xdist
+pytest-cov
pytest-retry
pytest-sugar
+pytest-xdist
tbparse

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.6
packaging==23.2
peft==0.15.1
-transformers==4.51.1
+transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0

View File

@@ -25,5 +25,5 @@ if cce_spec:
    print(
        UNINSTALL_PREFIX
-        + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
+        + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
    )

View File

@@ -67,7 +67,7 @@ def parse_requirements(extras_require_map):
    if (major, minor) >= (2, 6):
        _install_requires.pop(_install_requires.index(xformers_version))
        _install_requires.append("xformers==0.0.29.post2")
-        extras_require_map["vllm"] = ["vllm==0.8.1"]
+        extras_require_map["vllm"] = ["vllm==0.8.3"]
    elif (major, minor) >= (2, 5):
        _install_requires.pop(_install_requires.index(xformers_version))
        if patch == 0:

View File

@@ -0,0 +1,156 @@
"""
CLI tool to delinearize quantized/Linearized Llama-4 models.
"""
import os
from pathlib import Path
from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
keys = list(model_state_dict.keys())
for key in keys:
if ".feed_forward.experts." not in key:
yield key, model_state_dict[key]
if ".feed_forward.experts.gate_projs" in key:
# gate gets fused with up so skip the yield on this and we'll fuse it when asking for the up
continue
if ".feed_forward.experts.up_projs" in key:
if ".feed_forward.experts.up_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".up_projs.0.")[0]
key = f"{prefix}.gate_up_proj"
# grab all the up_projs and gate_projs across all experts
gate_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.gate_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
up_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.up_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
del gate_stacked, up_stacked
yield key, gate_up_proj
else:
del model_state_dict[key]
continue
if ".feed_forward.experts.down_projs" in key:
if ".feed_forward.experts.down_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".down_projs.0.")[0]
key = f"{prefix}.down_proj"
# grab all the down_projs across all experts
down_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.down_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
yield key, down_stacked
else:
del model_state_dict[key]
continue
def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
"""
Convert a patched HF format Llama4 model (with separated projections)
back to the original HF format (with fused projections).
Args:
model: Path to the patched HF model
output: Path to save the converted model
"""
print(f"Loading model from {model}")
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output)
device = model_.device.type
if device == "cuda":
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_config = model_.config
config = model_.config.get_text_config()
# Get key dimensions from the config
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
num_experts = config.num_local_experts
print(
f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
)
# Create output directory if it doesn't exist
os.makedirs(output, exist_ok=True)
# Get state dict
state_dict = model_.state_dict()
del model_
# Create a new state dict for the converted model
converted_state_dict = {}
# First, copy all keys that don't need modification
for key, value in iter_convert_patched_to_hf(state_dict, num_experts):
converted_state_dict[key] = value
del state_dict
if device == "cuda":
torch.cuda.empty_cache()
print("State dict converted.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
# Ideally re-load the model import to load the converted state dict
# Save the converted model
with init_empty_weights():
unpatch_llama4()
model_ = Llama4ForConditionalGeneration(model_config)
if device == "cuda":
print("State dict loaded into model.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_.load_state_dict(converted_state_dict, strict=False, assign=True)
print(f"Saving converted model to {output}...")
model_.save_pretrained(output)
print(f"Model successfully converted and saved to {output}")
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -330,6 +330,15 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
    do_vllm_serve(config, cli_args)


+@cli.command()
+@click.argument("model", type=click.Path(exists=True, path_type=str))
+@click.argument("output", type=click.Path(exists=False, path_type=str))
+def delinearize_llama4(model: str, output: str) -> None:
+    from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
+
+    do_delinearize_llama4(model, output)
+
+
cli.add_command(lm_eval)
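A short, illustrative way to exercise the new delinearization entry point programmatically rather than through the CLI (the input and output paths are placeholders, not paths from this changeset):

```python
# Illustrative only: invoke the new delinearization helper directly.
# Paths are placeholders; substitute a real patched/linearized Llama-4 checkpoint.
from axolotl.cli.delinearize_llama4 import do_cli

do_cli(
    model="./outputs/out/merged",            # patched (linearized) Llama-4 model directory (placeholder)
    output="./outputs/llama4-delinearized",  # destination for the fused-projection HF model (placeholder)
)
```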

View File

@@ -40,6 +40,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.warning("Error raised: %s", e) LOG.warning("Error raised: %s", e)
model.generation_config.do_sample = True model.generation_config.do_sample = True
model.config.use_cache = True
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...") LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")

View File

@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sequence_parallel_degree"] = ( training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree self.cfg.sequence_parallel_degree
) )
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
kwargs["return_tensors"] = "pt" kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq): if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator( return collator(
*collator_args, *collator_args,

View File

@@ -9,6 +9,8 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig

+from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
+

@dataclass
class AxolotlTrainingMixins:
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
        default=1,
        metadata={"help": "The number of workers to use in sequence parallelism"},
    )
+    ring_attn_func: Optional[RingAttnFunc] = field(
+        default=None,
+        metadata={
+            "help": "The ring-flash-attn function to use in sequence parallelism"
+        },
+    )

    # multi-modal section

View File

@@ -12,12 +12,14 @@ See https://github.com/apple/ml-cross-entropy
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.

+- If you are in dev environment
+
```bash
-# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
+```
+
+- If you are installing from pip

-# if you are not in dev environment
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
+```bash
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
```

## Usage

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
    "Please install cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
)

View File

@@ -165,7 +165,7 @@ def cce_forward(
)


def cce_forward_multimodal(
    self,
-    input_ids: torch.LongTensor | None = None,
+    input_ids: torch.LongTensor | None = None,  # type: ignore
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
    )

    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(input_ids)
+        inputs_embeds = self.get_input_embeddings()(input_ids)  # type: ignore

    if pixel_values is not None:
        image_features = self.get_image_features(
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )

-        original_inputs_embeds_shape = inputs_embeds.shape
+        original_inputs_embeds_shape = inputs_embeds.shape  # type: ignore

        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-        final_mask = special_image_mask.to(inputs_embeds.device)
+        final_mask = special_image_mask.to(inputs_embeds.device)  # type: ignore
        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))  # type: ignore

        final_mask_1d = final_mask[..., 0].reshape(-1)

View File

@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
        )
        sharded_sd[param_name] = sharded_tensor

-    model.load_state_dict(sharded_sd)
+    model.load_state_dict(sharded_sd, assign=True)


def patch_accelerate_fsdp_utils():
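The one-line change above switches `load_state_dict` to tensor assignment. A toy, illustrative sketch of the behavioral difference on recent PyTorch (the module and state dict below are invented for the example):

```python
# Illustrative toy example of load_state_dict(..., assign=True) on recent PyTorch:
# with assign=True the loaded tensors replace the module's parameters outright,
# preserving their dtype/device, instead of being copied into the existing storage.
import torch

lin = torch.nn.Linear(2, 2)  # parameters start out as fp32
new_sd = {
    "weight": torch.zeros(2, 2, dtype=torch.bfloat16),
    "bias": torch.zeros(2, dtype=torch.bfloat16),
}

lin.load_state_dict(new_sd, assign=True)
print(lin.weight.dtype)  # torch.bfloat16 -- the parameter was swapped, not cast-copied into fp32
```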

View File

@@ -7,12 +7,11 @@ import torch
import transformers


-def patch_flex_wrapper():
+def patch_flex_wrapper(**flex_attn_compile_kwargs):
    # TODO remove this patch when transformers#37285 is merged and in a release
    is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"

-    if not (is_torch_2_6 and is_transformers_below_4_51):
+    if not is_torch_2_6:
        return

    from torch.nn.attention.flex_attention import flex_attention
@@ -32,17 +31,24 @@ def patch_flex_wrapper():
                cls._instance = super().__new__(cls)
            return cls._instance

+        @classmethod
+        def del_singleton(cls):
+            cls._instance = None
+
        @torch.compiler.disable(recursive=False)
-        def __init__(self):
+        def __init__(self, training):
            """
            Initialize or update the singleton instance.
            """
-            if not self._is_flex_compiled:
+            self.training = None
+            if not self._is_flex_compiled or training != self.training:
+                # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+                # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+                # see https://github.com/pytorch/pytorch/issues/146260 for training
+                self.training = training
                self._compiled_flex_attention = torch.compile(
                    flex_attention,
-                    dynamic=False,
-                    mode="max-autotune-no-cudagraphs",
-                    fullgraph=True,
+                    **flex_attn_compile_kwargs,
                )
                self._is_flex_compiled = True
@@ -50,15 +56,22 @@ def patch_flex_wrapper():
            return self._compiled_flex_attention

    transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
+    setattr(
+        sys.modules["transformers.integrations.flex_attention"],
+        "WrappedFlexAttention",
+        WrappedFlexAttention,
+    )


def patch_flex_make_mask():
    is_torch_2_6 = torch.__version__.startswith("2.6")
-    is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"

-    if not (is_torch_2_6 and is_transformers_eq_4_51):
+    if not is_torch_2_6:
        return

+    from torch.nn.attention.flex_attention import (
+        _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
+    )
    from torch.nn.attention.flex_attention import (
        BlockMask,
    )
@@ -104,14 +117,16 @@ def patch_flex_make_mask():
        if not query_length:
            query_length = total_seq_len

        attention_mask_2d = torch.nn.functional.pad(
-            attention_mask_2d, value=0, pad=(0, key_length)
+            attention_mask_2d,
+            value=0,
+            pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
        )
        device = attention_mask_2d.device
        document_ids = attention_mask_2d.clone()

        if attention_chunk_size is not None:
            # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
-            document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
+            chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
                attention_chunk_size
            )
@@ -138,6 +153,18 @@ def patch_flex_make_mask():
            final_mask = causal_mask & padding_mask & document_mask
            return final_mask

+        def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+            """
+            Combines the chunk mask with the causal mask for chunked attention.
+            """
+            chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
+            causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
+            return chunk_mask & causal_doc_mask
+
+        mask_mod_maybe_combined = (
+            causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
+        )
+
        if offsets is not None:
            q_offset = offsets[0]
            kv_offset = offsets[1]
@@ -145,10 +172,10 @@ def patch_flex_make_mask():
            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                offset_q = q_idx + q_offset
                offset_kv = kv_idx + kv_offset
-                return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+                return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)

        else:
-            mask_mod = causal_mask_mod
+            mask_mod = mask_mod_maybe_combined

        return create_block_causal_mask_flex(
            mask_mod=mask_mod,
            B=batch_size,
@@ -160,11 +187,16 @@ def patch_flex_make_mask():
    )

    for n in tuple(sys.modules):
-        if ".modeling_" in n and "llama4" not in n:
+        if ".modeling_" in n:
            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
                sys.modules[n].make_flex_block_causal_mask = (
                    patched_make_flex_block_causal_mask
                )
+                setattr(
+                    sys.modules[n],
+                    "make_flex_block_causal_mask",
+                    patched_make_flex_block_causal_mask,
+                )

    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
        patched_make_flex_block_causal_mask
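The patched mask builder derives per-token chunk indices with a cumulative sum before combining chunk and causal masks. A tiny illustration of that arithmetic on made-up sizes (nothing here is taken from the change itself beyond the formula):

```python
# Toy illustration of the chunk-index arithmetic used by the patched flex-attention mask.
# Batch size, sequence length, and chunk size are invented for the example.
import torch

attention_chunk_size = 3
attention_mask_2d = torch.ones(1, 8, dtype=torch.long)  # batch of one, 8 tokens

# (cumsum of ones - 1) // chunk_size -> per-token chunk ids [0, 0, 0, 1, 1, 1, 2, 2]
chunk_idxs = (attention_mask_2d.clone().fill_(1).cumsum(-1) - 1) // attention_chunk_size
print(chunk_idxs.tolist())  # [[0, 0, 0, 1, 1, 1, 2, 2]]

# A query may only attend to keys in its own chunk (combined with the causal mask):
q_idx, kv_idx = 4, 2
same_chunk = chunk_idxs[0, q_idx] == chunk_idxs[0, kv_idx]  # False: chunk 1 vs chunk 0
print(bool(same_chunk))
```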

View File

@@ -0,0 +1,12 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -0,0 +1,192 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""
# pylint: disable=protected-access,cyclic-import
import os
from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
}
def create_flash_attn_forward(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
"""
Create a ring flash attention forward function compatible with HuggingFace's
interface.
Args:
process_group: A PyTorch distributed process group.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
Returns:
A function that implements the ring flash attention forward pass with the
signature expected by HuggingFace Transformers.
"""
# transformers 4.48+
# pylint: disable=unused-argument
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: torch.Tensor | None = None,
softmax_scale: float | None = None,
sliding_window: int | None = None,
use_top_left_mask: bool = False,
softcap: float | None = None,
deterministic: bool = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
**kwargs,
):
"""
Calls the forward method of Ring Flash Attention.
Args:
query_states: Tensor containing the query vectors.
key_states: Tensor containing the key vectors.
value_states: Tensor containing the value vectors.
attention_mask: Not used in this implementation.
query_length: Integer representing the length of the query sequence.
is_causal: Boolean indicating whether to apply a causal mask to the attention.
dropout: Float representing the dropout probability. Default is 0.0.
position_ids: Not used in this implementation.
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
sliding_window: Optional integer defining the size of the sliding attention window.
Default is None.
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
Default is False.
softcap: Not used in this implementation.
deterministic: Optional boolean to enforce deterministic computation. Default is None.
cu_seq_lens_q: Not used in this implementation.
cu_seq_lens_k: Not used in this implementation.
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
torch.Tensor: The output of the attention mechanism, with shape
`[batch_size, query_length, num_heads, head_dim]`.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
window_size = (
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
)
# Handle deterministic mode
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = (
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
# Call ring flash attention function
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
query_states,
key_states,
value_states,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
group=process_group,
)
return attn_output
return _flash_attention_forward
def substitute_hf_flash_attn(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
"""
Substitute HuggingFace's flash attention implementation with ring-based implementation.
Args:
process_group: PyTorch distributed process group for communication.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
"""
try:
# Substitute flash attention
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_flash_attn_forward(
process_group=process_group, ring_attn_func=ring_attn_func
)
if check_params(old_flash_attention_forward, new_flash_attention_forward):
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
else:
raise ValueError(
"The signature of the new flash attention forward function does not match the old one."
)
except Exception as exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"Please use pip install -U transformers to upgrade to the latest version. "
"If the code failed with the latest version, "
f"please file an issue."
) from exception
# Register with ALL_ATTENTION_FUNCTIONS if available
if ALL_ATTENTION_FUNCTIONS is not None:
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward

View File

@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
their sequence parallel version of Flash Attention 2.
"""

+from enum import Enum
+
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
configure_logging()
LOG = get_logger(__name__)

RING_ATTN_GROUP = None
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
    RING_ATTN_GROUP = ring_attn_group


-def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
+class RingAttnFunc(str, Enum):
+    """Enum class for supported `ring-flash-attn` implementations"""
+
+    # VARLEN_RING = "varlen_ring"
+    # VARLEN_ZIGZAG = "varlen_zigzag"
+    VARLEN_LLAMA3 = "varlen_llama3"
+    BATCH_RING = "batch_ring"
+    BATCH_ZIGZAG = "batch_zigzag"
+    BATCH_STRIPE = "batch_stripe"
+
+
+def register_ring_attn(
+    sequence_parallel_degree: int,
+    heads_k_stride: int | None,
+    ring_attn_func: RingAttnFunc | None,
+):
    """
    Create ring attention group and substitute flash attn with ring flash attn.
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
        sequence_parallel_degree: Sequence parallelism factor.
        heads_k_stride: Sequence parallelism K head stride size. Passed
            through to `ring_flash_attn.substitute_hf_flash_attn`.
+        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
+            packing is enabled, it must be a `varlen` function; otherwise, it must be a
+            `batch` function.
    """
    if get_ring_attn_group() is not None:
        LOG.info("Ring attention already registered, exiting early...")
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"each sequence will be processed across {sequence_parallel_degree} GPUs" f"each sequence will be processed across {sequence_parallel_degree} GPUs"
) )
rank = dist.get_rank()
world_size = dist.get_world_size() world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, ( assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) " f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})" f"must be less than or equal to world_size ({world_size})"
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
f"must evenly divide world_size ({world_size})" f"must evenly divide world_size ({world_size})"
) )
# Detailed logging of group formation # Assign ranks to sequence parallel groups
rank = dist.get_rank()
group_assignments = {} group_assignments = {}
for i in range(world_size // sequence_parallel_degree): for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list( ring_attn_ranks = list(
range( range(
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
    if rank == 0:
        LOG.info(f"Sequence parallel group assignments: {group_assignments}")

-    if heads_k_stride is None:
-        heads_k_stride = 1
-
-    from ring_flash_attn import substitute_hf_flash_attn
-
-    substitute_hf_flash_attn(
-        process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
-    )
+    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
+        from ring_flash_attn import substitute_hf_flash_attn
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func in [
+        RingAttnFunc.BATCH_RING,
+        RingAttnFunc.BATCH_ZIGZAG,
+        RingAttnFunc.BATCH_STRIPE,
+    ]:
+        from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )


-def update_ring_attn_params(batch: dict[str, torch.Tensor]):
+def update_ring_attn_params(position_ids: torch.Tensor | None):
    """
    Calculate the cumulative sequence lengths for the current forward pass and pass the
    value to the substituted `ring_flash_attn`.

    Args:
-        batch: A dictionary with a batch of data. May or may not contain `position_ids`
-            data; if not, we compute it.
+        position_ids: Optional tensor of position IDs (for sample packed data).
    """
    from ring_flash_attn import update_ring_flash_attn_params

-    input_ids = batch["input_ids"]
-    position_ids = batch.get("position_ids")
-    if position_ids is None:
-        seq_len = input_ids.shape[1]
-        position_ids = torch.arange(
-            0, seq_len, dtype=torch.long, device=input_ids.device
-        ).unsqueeze(0)
-
    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
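As a concrete illustration of the group-assignment loop above, with an example world size of 8 and `sequence_parallel_degree` of 4 (values are illustrative, not from this changeset), ranks are partitioned into two ring-attention groups:

```python
# Illustrative: how consecutive ranks are partitioned into sequence-parallel groups.
# world_size and sequence_parallel_degree are example values only.
world_size = 8
sequence_parallel_degree = 4

groups = [
    list(range(i * sequence_parallel_degree, (i + 1) * sequence_parallel_degree))
    for i in range(world_size // sequence_parallel_degree)
]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- each inner list forms one ring-attention group
```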

View File

@@ -93,9 +93,20 @@ def patch_llama4_linearized_modeling():
""" """
from transformers.models.llama4 import modeling_llama4 from transformers.models.llama4 import modeling_llama4
old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts
modeling_llama4.Llama4TextExperts = Llama4TextExperts modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr( setattr(
sys.modules["transformers.models.llama4"], sys.modules["transformers.models.llama4"],
"Llama4TextExperts", "Llama4TextExperts",
Llama4TextExperts, Llama4TextExperts,
) )
def unpatch():
modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
old_lamma_4_text_experts,
)
return unpatch

View File

@@ -0,0 +1,78 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
training_loop = inspect.getsource(Trainer.evaluation_loop)
return training_loop
def check_evaluation_loop_is_patchable() -> bool:
eval_loop = get_evaluation_loop_code()
eval_loop, _ = detab_code(eval_loop)
return ORIGINAL_TRAINER_CODE in eval_loop
def patch_evaluation_loop_for_fsdp2():
"""
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
"""
try:
evaluation_loop = get_evaluation_loop_code()
except OSError:
return
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
evaluation_loop
)
evaluation_loop, _ = detab_code(evaluation_loop)
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
return
evaluation_loop = evaluation_loop.replace(
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
)
evaluation_loop = evaluation_loop.replace(
"def evaluation_loop(",
"def _fixed_evaluation_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in evaluation_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer.evaluation_loop = ( # pylint: disable=protected-access
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -81,6 +81,11 @@ def setup_model_and_tokenizer(
    # Apply freezing if specified
    if cfg.unfrozen_parameters:
        freeze_layers_except(model, cfg.unfrozen_parameters)
+        if any(
+            any(embed in param for embed in ["lm_head", "embed_tokens"])
+            for param in cfg.unfrozen_parameters
+        ):
+            model.enable_input_require_grads()

    return model, tokenizer, peft_config, processor

View File

@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Union from typing import Any
import numpy as np import numpy as np
import torch import torch
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
+from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc


@dataclass
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
""" """
tokenizer: PreTrainedTokenizerBase tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None model: Any | None = None
padding: Union[bool, str, PaddingStrategy] = True padding: bool | str | PaddingStrategy = True
max_length: Optional[int] = None max_length: int | None = None
pad_to_multiple_of: Optional[int] = None pad_to_multiple_of: int | None = None
label_pad_token_id: int = -100 label_pad_token_id: int = -100
position_pad_token_id: int = 0 position_pad_token_id: int = 0
return_tensors: str = "pt" return_tensors: str = "pt"
sequence_parallel_degree: int = 1 sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self): def __post_init__(self):
if self.sequence_parallel_degree > 1: if self.sequence_parallel_degree > 1:
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
            Sliced batch dictionary.
        """
        # Get local (start, end) for sequence parallelism slicing
-        total_seq_len = batch["input_ids"].shape[1]
-        slice_size = total_seq_len // self.local_world_size
-        start = self.local_rank * slice_size
-        end = start + slice_size
+        total_seq_len = batch["input_ids"].size(1)

-        # Update params for ring attention calculation
-        update_ring_attn_params(batch=batch)
+        # Update params for varlen ring attention calculation
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(position_ids=batch["position_ids"])

        # Slice batch for sequence parallel processing
-        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
-        for key in keys_to_slice:
-            if key in batch:
-                batch[key] = batch[key][:, start:end]
+        for key in batch:
+            if batch[key].size(1) == total_seq_len:
+                if self.ring_attn_func in [
+                    RingAttnFunc.VARLEN_LLAMA3,
+                    RingAttnFunc.BATCH_RING,
+                ]:
+                    batch[key] = (
+                        batch[key]
+                        .chunk(self.local_world_size, dim=1)[self.local_rank]
+                        .contiguous()
+                    )
+                elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
+
+                    # Take rank's chunk and opposing chunk for zigzag pattern
+                    selected_chunks = [
+                        chunks[self.local_rank],
+                        chunks[2 * self.local_world_size - self.local_rank - 1],
+                    ]
+                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+                elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
+                    # TODO(djsaunde): This doesn't seem to work as expected
+                    # Split into striped data and stack
+                    tensor = torch.stack(
+                        batch[key].split(self.local_world_size, dim=1),
+                        dim=1,
+                    ).transpose(1, 2)
+                    batch[key] = tensor[:, self.local_rank].contiguous()

        return batch
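A toy illustration of the `BATCH_ZIGZAG` slicing performed by the collator above; only the chunk-selection logic mirrors the diff, while the tensor, world size, and rank are invented:

```python
# Toy illustration of zigzag chunk selection: each rank takes its own chunk plus the
# mirrored chunk from the end, balancing causal-attention work across ranks.
import torch

local_world_size, local_rank = 4, 1
seq = torch.arange(16).unsqueeze(0)                 # [1, 16] toy "input_ids"

chunks = seq.chunk(2 * local_world_size, dim=1)     # 8 chunks of length 2
selected = [chunks[local_rank], chunks[2 * local_world_size - local_rank - 1]]
print(torch.cat(selected, dim=1))                   # tensor([[ 2,  3, 12, 13]]) for rank 1
```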

View File

@@ -332,16 +332,23 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
if isinstance(dataset, IterableDataset): if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
def gen_from_iter_ds(_ds, _=None): def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
yield from _ds """Generator function to correctly splice the dataset for each worker"""
for i, item in enumerate(_ds):
if i % num_workers[0] == worker_id[0]:
yield item
ds_from_iter = Dataset.from_generator( ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset), functools.partial(gen_from_iter_ds, dataset),
features=dataset.features, features=dataset.features,
num_proc=cfg.dataset_processes, num_proc=num_workers,
split=split, split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))}, gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
) )
ds_from_iter.save_to_disk(str(prepared_ds_path)) ds_from_iter.save_to_disk(str(prepared_ds_path))
else: else:
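The reworked generator above shards the iterable dataset so each worker yields only every num_workers-th item instead of the whole dataset. A minimal sketch of that round-robin split, using a hypothetical iterable in place of the streamed `IterableDataset`:

```python
def gen_from_iter_ds(_ds, worker_id, num_workers):
    # worker_id and num_workers arrive as single-element lists via gen_kwargs
    for i, item in enumerate(_ds):
        if i % num_workers[0] == worker_id[0]:
            yield item

items = list(range(10))  # hypothetical stand-in for the streamed dataset
shards = [
    list(gen_from_iter_ds(items, worker_id=[w], num_workers=[4])) for w in range(4)
]
# shards == [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]] -- disjoint and covers every item
```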
View File
@@ -2,13 +2,14 @@
module to freeze/unfreeze parameters by name module to freeze/unfreeze parameters by name
""" """
import logging
import re import re
from typing import Callable, List, Tuple, Union from typing import Callable, List, Tuple, Union
from accelerate.logging import get_logger
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.utils.freeze") LOG = get_logger(__name__)
def freeze_layers_except(model, regex_patterns): def freeze_layers_except(model, regex_patterns):
@@ -184,7 +185,7 @@ class LayerNamePattern:
""" """
self.raw_pattern = pattern self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern) name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(name_pattern.replace(".", "\\.")) self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
def match(self, name: str) -> bool: def match(self, name: str) -> bool:
""" """
View File
@@ -542,6 +542,17 @@ class ModelLoader:
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils() patch_accelerate_fsdp_utils()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
patch_flex_make_mask()
# patch gemma3 conditional generation forward before loading plugins # patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins # as it could be overridden by plugins
if self.cfg.model_config_type == "llama4": if self.cfg.model_config_type == "llama4":
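With the patching moved here, the new `flex_attn_compile_kwargs` config field is forwarded verbatim to `patch_flex_wrapper()`. A hedged sketch of what a config fragment might look like (the kwarg names are assumptions, not taken from this diff):

```python
from axolotl.utils.dict import DictDefault

# Hypothetical config fragment in the Python-dict form the e2e tests use.
cfg = DictDefault(
    {
        "flex_attention": True,
        # assumed example: whatever keyword arguments patch_flex_wrapper accepts
        # (compile options that depend on the model's parameters)
        "flex_attn_compile_kwargs": {"dynamic": False},
    }
)
```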
@@ -644,6 +655,7 @@ class ModelLoader:
register_ring_attn( register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree, sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride, heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
) )
def patch_attention(self) -> None: def patch_attention(self) -> None:
@@ -905,13 +917,6 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention" "flex_attention"
) )
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention: elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
@@ -1115,7 +1120,7 @@ class ModelLoader:
return skip_move_to_device return skip_move_to_device
def ajust_model_config(self) -> None: def adjust_model_config(self) -> None:
if ( if (
hasattr(self.model, "config") hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings") and hasattr(self.model.config, "max_position_embeddings")
@@ -1275,7 +1280,7 @@ class ModelLoader:
else: else:
self.model.tie_weights() self.model.tie_weights()
self.ajust_model_config() self.adjust_model_config()
# log device memory usage # log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ( if hasattr(self.model, "device") and self.model.device.type in (
View File
@@ -225,6 +225,7 @@ class AxolotlInputConfig(
sdp_attention: bool | None = None sdp_attention: bool | None = None
s2_attention: bool | None = None s2_attention: bool | None = None
flex_attention: bool | None = None flex_attention: bool | None = None
flex_attn_compile_kwargs: dict[str, Any] | None = None
flash_attention: bool | None = None flash_attention: bool | None = None
flash_attn_cross_entropy: bool | None = None flash_attn_cross_entropy: bool | None = None
flash_attn_rms_norm: bool | None = None flash_attn_rms_norm: bool | None = None
@@ -258,6 +259,7 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None heads_k_stride: int | None = None
ring_attn_func: str | None = None
special_tokens: SpecialTokensConfig | None = None special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None tokens: list[str] | None = None
@@ -1146,7 +1148,7 @@ class AxolotlInputConfig(
return data return data
@field_validator("sequence_parallel_degree", mode="before") @field_validator("sequence_parallel_degree", mode="after")
@classmethod @classmethod
def check_sequence_parallel_degree(cls, value, info): def check_sequence_parallel_degree(cls, value, info):
if not value: if not value:
@@ -1158,9 +1160,12 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with sequence_parallel_degree > 1"
) )
if not info.data["micro_batch_size"] == 1: if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
raise ValueError( raise ValueError(
"micro_batch_size must be set to 1 " "micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement" "due to a `ring-flash-attn` requirement"
) )
@@ -1187,6 +1192,34 @@ class AxolotlInputConfig(
return value return value
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_muon_deepspeed_fsdp(cls, data): def check_muon_deepspeed_fsdp(cls, data):
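The validator above falls back to a default ring attention function when none is configured. A minimal sketch of that selection logic in isolation (enum member values assumed from the test parameters elsewhere in this changeset):

```python
from enum import Enum

class RingAttnFunc(str, Enum):
    # assumed string values, mirroring the config strings used in the tests
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"
    BATCH_STRIPE = "batch_stripe"

def default_ring_attn_func(sample_packing: bool) -> RingAttnFunc:
    # packed (varlen) batches use the llama3 varlen kernel;
    # unpacked batches default to plain batch ring attention
    return RingAttnFunc.VARLEN_LLAMA3 if sample_packing else RingAttnFunc.BATCH_RING

assert default_ring_attn_func(True) is RingAttnFunc.VARLEN_LLAMA3
assert default_ring_attn_func(False) is RingAttnFunc.BATCH_RING
```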
@@ -1276,11 +1309,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
): ):
capabilities = data.get("capabilities") capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None is_fsdp = data.get("fsdp") is not None
is_fsdp2 = (
if capabilities and capabilities.get("n_gpu", 0) > 1: data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
if is_fsdp: if is_fsdp:
raise ValueError( raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP." "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
) )
return data return data

View File
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -235,7 +236,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type in ["mamba", "gemma3"]: drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
@@ -625,6 +627,12 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters. on the provided parameters.
""" """
if (
cfg.torch_compile
and cfg.fsdp_config
and str(cfg.fsdp_config.fsdp_version) == "2"
):
patch_evaluation_loop_for_fsdp2()
if cfg.rl: if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor) trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref trainer_builder.model_ref = model_ref
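The guard added above only patches the evaluation loop when both torch.compile and FSDP v2 are in play. A small sketch of the same condition as a standalone helper (hypothetical, for illustration only):

```python
def needs_fsdp2_eval_patch(torch_compile: bool, fsdp_config: dict | None) -> bool:
    # patch only when torch.compile is enabled and the FSDP config requests version 2
    return bool(
        torch_compile
        and fsdp_config
        and str(fsdp_config.get("fsdp_version")) == "2"
    )

assert needs_fsdp2_eval_patch(True, {"fsdp_version": 2})
assert not needs_fsdp2_eval_patch(True, None)
assert not needs_fsdp2_eval_patch(False, {"fsdp_version": 2})
```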
View File
@@ -496,6 +496,12 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"] return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(scope="session", autouse=True)
def download_tiny_llama_7m_model():
# download the model
return snapshot_download_w_retry("axolotl-ai-internal/llama-7m", repo_type="model")
# # pylint: disable=redefined-outer-name,unused-argument # # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures( # def test_load_fixtures(
# download_smollm2_135m_model, # download_smollm2_135m_model,
View File
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
) )
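The reworded assertion messages now carry a `%s` placeholder so the failing value is printed. A hypothetical stand-in for the `check_tensorboard` helper shows how the placeholder gets filled:

```python
def check_loss(train_loss: float, threshold: float, assertion_err: str) -> None:
    # hypothetical helper: interpolate the observed loss into the failure message
    assert train_loss < threshold, assertion_err % train_loss

check_loss(0.8, 1.0, "Train loss (%s) is too high")    # passes silently
# check_loss(1.2, 1.0, "Train loss (%s) is too high")  # AssertionError: Train loss (1.2) is too high
```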
View File
@@ -56,11 +56,12 @@ class TestPackedFlex:
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", "optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 5, "max_steps": 2,
"use_tensorboard": True, "use_tensorboard": True,
"save_strategy": "no", "save_strategy": "no",
} }
@@ -88,5 +89,5 @@ class TestPackedFlex:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
View File
@@ -177,6 +177,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", "NCCL_P2P_LEVEL": "LOC",
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
} }
vllm_process_id = start_vllm( vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
@@ -264,6 +265,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
} }
vllm_process_id = start_vllm( vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
View File
@@ -96,5 +96,5 @@ class TestMultiGPUGemma3:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 1.8, "Train loss (%s) is too high"
) )
View File
@@ -43,7 +43,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
@@ -94,7 +94,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -105,7 +105,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": True, "sample_packing": True,
"eval_sample_packing": False, "eval_sample_packing": False,
@@ -159,14 +159,14 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
def test_dpo_lora_ddp(self, temp_dir): def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
@@ -244,7 +244,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
@@ -326,7 +326,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"val_set_size": 0.01, "val_set_size": 0.01,
"special_tokens": { "special_tokens": {
@@ -385,7 +385,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -396,7 +396,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -457,7 +457,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@require_torch_2_6_0 @require_torch_2_6_0
@@ -475,7 +475,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 2048, "sequence_len": 2048,
@@ -538,7 +538,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.1, "Train loss (%s) is too high"
) )
def test_fsdp_qlora_prequant_packed(self, temp_dir): def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -618,15 +618,9 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 2],
@@ -660,7 +654,7 @@ class TestMultiGPULlama:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -708,7 +702,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -734,7 +728,7 @@ class TestMultiGPULlama:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -782,7 +776,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -808,7 +802,7 @@ class TestMultiGPULlama:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -856,7 +850,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.skip( @pytest.mark.skip(
@@ -866,7 +860,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"fix_untrained_tokens": True, "fix_untrained_tokens": True,
"sequence_len": 512, "sequence_len": 512,
"val_set_size": 0.0, "val_set_size": 0.0,
@@ -923,5 +917,5 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 4.0, "Train loss (%s) is too high"
) )
View File
@@ -80,7 +80,7 @@ class TestMultiGPURay:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@require_torch_lt_2_6_0 @require_torch_lt_2_6_0
@@ -138,5 +138,5 @@ class TestMultiGPURay:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
View File
@@ -3,6 +3,7 @@
import os import os
from pathlib import Path from pathlib import Path
import pytest
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism: class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled""" """Test case for training with sequence parallelism enabled"""
def test_sequence_parallel_training(self, temp_dir): def _run_sequence_parallel_test(
# pylint: disable=duplicate-code self,
temp_dir,
sample_packing=True,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",
"sample_packing": True, "sample_packing": sample_packing,
"eval_sample_packing": True, "eval_sample_packing": sample_packing,
"pad_to_sequence_len": True, "pad_to_sequence_len": pad_to_sequence_len,
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 8, "max_steps": 8,
"micro_batch_size": 1, "micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True, "use_tensorboard": True,
"sequence_parallel_degree": 2, "sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
} }
) )
@@ -84,5 +93,37 @@ class TestSequenceParallelism:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.6, "Train loss (%s) is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
self,
temp_dir,
sample_packing,
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
) )
View File
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration(): def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model.""" """Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto" "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
) )
peft_config = get_peft_config( peft_config = get_peft_config(
{ {
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures.""" """Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype # Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto" model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
) )
# Apply LoRA configuration # Apply LoRA configuration
View File
@@ -86,5 +86,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 1.5, "Train loss (%s) is too high"
) )
View File
@@ -73,7 +73,10 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state self, mock_world_size, mock_rank, mock_new_group, partial_state
): ):
"""Test that ring attention groups are created correctly.""" """Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks # Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total mock_world_size.return_value = 8 # 8 GPUs total
@@ -82,7 +85,11 @@ class TestRingAttention:
mock_new_group.return_value = mock_group mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4 # Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1) register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments # Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2 assert mock_new_group.call_count == 2
View File
@@ -80,7 +80,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
def test_unsloth_llama_qlora_unpacked(self, temp_dir): def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -130,7 +130,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -185,5 +185,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
View File
@@ -69,5 +69,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
View File
@@ -84,5 +84,5 @@ class TestPretrainLlama:
temp_dir + "/runs", temp_dir + "/runs",
"train/train_loss", "train/train_loss",
loss_threshold, loss_threshold,
"Train Loss is too high", "Train Loss (%s) is too high",
) )
View File
@@ -68,5 +68,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
View File
@@ -73,6 +73,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.5, "Train loss (%s) is too high"
) )
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
View File
@@ -8,7 +8,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.callbacks.perplexity import Perplexity
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M" MODEL_NAME = "axolotl-ai-internal/llama-7m"
@fixture() @fixture()
@@ -36,7 +36,7 @@ One day, a little fish named Fin was swimming near the shore. He saw a big crab
""" """
result = metric.compute(model, [sample_text]) result = metric.compute(model, [sample_text])
ppl = result["score"] ppl = result["score"]
assert round(ppl, 2) == 7.41 assert round(ppl, 2) == 75.14
def test_perplexity_short(model, metric): def test_perplexity_short(model, metric):
@@ -44,4 +44,4 @@ def test_perplexity_short(model, metric):
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun." sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
result = metric.compute(model, [sample_text]) result = metric.compute(model, [sample_text])
ppl = result["score"] ppl = result["score"]
assert round(ppl, 2) == 10.33 assert round(ppl, 2) == 70.54
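The expected perplexity scores change here because the fixture model is swapped for the much smaller llama-7m. For reference, a minimal sketch of how such a score is computed (exponential of the mean per-token negative log-likelihood), using hypothetical loss values:

```python
import math

# Hypothetical per-token negative log-likelihoods from a causal LM forward pass
token_nlls = [4.1, 4.5, 4.2, 4.4]

perplexity = math.exp(sum(token_nlls) / len(token_nlls))
print(round(perplexity, 2))  # ~73.7 -- the same order of magnitude as the new scores
```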