Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
0aa7c72c59 bump transformers to 4.51.3 2025-04-14 07:49:18 -07:00
104 changed files with 416 additions and 3169 deletions

View File

@@ -1,14 +0,0 @@
[run]
source = axolotl
omit =
*/tests/*
setup.py
[report]
exclude_lines =
pragma: no cover
def __repr__
raise NotImplementedError
if __name__ == .__main__.:
pass
raise ImportError

View File

@@ -46,18 +46,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -29,13 +29,8 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -98,11 +93,6 @@ jobs:
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -148,7 +138,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -45,13 +45,6 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -74,7 +67,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -147,7 +147,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -102,17 +102,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -242,7 +234,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -270,12 +261,6 @@ jobs:
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -296,7 +281,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -9,7 +9,6 @@
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<br/>
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>

View File

@@ -3,53 +3,10 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/patched/lora_kernels \
--cov=axolotl \
--cov-append
# Run patched tests excluding lora kernels with coverage append
pytest -v --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \
--cov-append
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/solo/ \
--cov=axolotl \
--cov-append
# Run integration tests with coverage append
pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/integrations/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 /workspace/axolotl/tests/cli \
--cov=axolotl \
--cov-append
# Run remaining e2e tests with coverage append and final report
pytest -v --durations=10 \
--ignore=tests/e2e/solo/ \
--ignore=tests/e2e/patched/ \
--ignore=tests/e2e/multigpu/ \
--ignore=tests/e2e/integrations/ \
--ignore=tests/cli \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:e2e-coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/

View File

@@ -28,7 +28,6 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -29,7 +29,6 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -1,23 +1,6 @@
#!/bin/bash
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/

View File

@@ -1,56 +0,0 @@
codecov:
require_ci_to_pass: yes
notify:
wait_for_ci: true
coverage:
precision: 2
round: down
range: "70...100"
status:
project:
default:
# basic
target: auto
threshold: 0%
base: auto
# advanced
branches: null
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
flags: null
paths: null
patch:
default:
# basic
target: auto
threshold: 0%
base: auto
# advanced
branches: null
if_no_uploads: error
if_not_found: success
if_ci_failed: error
only_pulls: false
flags: null
paths: null
parsers:
gcov:
branch_detection:
conditional: yes
loop: yes
method: no
macro: no
comment:
layout: "reach,diff,flags,files,footer"
behavior: default
require_changes: no
require_base: no
require_head: yes
github_checks:
annotations: false

View File

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,17 +199,6 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -693,9 +693,6 @@ sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
# in the sample packing case, and "batch_ring" in the non-sample packing case.
ring_attn_func:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -19,12 +19,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}

View File

@@ -27,9 +27,6 @@ To enable sequence parallelism, add the following to your configuration file:
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

View File

@@ -1,62 +0,0 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,36 +1,16 @@
# Llama 4 by Meta AI
## Flash Attention vs Flex Attention
While Flash Attention to support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
## Available Examples
### Llama 4 Scout 17Bx16Experts (109B)
- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
Flex Attention
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
[//]: # (Flash Attention &#40;Do not use&#41;)
[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1]&#40;./scout-vision-qlora-fsdp.yaml&#41;)
[//]: # (- [Text Single GPU &#40;H100&#41; QLoRA]&#40;./scout-qlora-single-h100.yaml&#41;)
[//]: # (- [Text Multi GPU QLoRA w/ FSDP1]&#40;./scout-qlora-fsdp1.yaml&#41;)
Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @ 280tps/gpu, [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon
- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)

View File

@@ -1,86 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head
# - embed_tokens
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -1,85 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.liger.LigerPlugin
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
# - experts.gate_projs.[0-9]+$ # optionally train the moe experts
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
# - lm_head # needed if modifying vocabulary
# - embed_tokens
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
chat_template: llama4
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
torch_compile: true
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -1,89 +0,0 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
sequence_len: 4096
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
llama4_linearized_experts: true # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
- self_attn.q_proj
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- shared_expert.gate_proj
- shared_expert.up_proj
- shared_expert.down_proj
- vision_adapter.mlp.fc1
- vision_adapter.mlp.fc2
# - experts.gate_projs.[0-9]+$
# - experts.up_projs.[0-9]+$
# - experts.down_projs.[0-9]+$
lora_modules_to_save:
- lm_head
- embed_tokens
chat_template: llama4
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4
bf16: true
tf32: true
logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- auto_wrap
- full_shard
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>

View File

@@ -1,6 +1,6 @@
pre-commit
black
mypy
pre-commit
types-requests
quartodoc
jupyter

View File

@@ -1,8 +1,5 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-xdist
pytest-retry
pytest-sugar
pytest-xdist
tbparse

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.8
liger-kernel==0.5.6
# END section
packaging==23.2
@@ -19,7 +19,6 @@ datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2
hf_transfer

View File

@@ -25,5 +25,5 @@ if cce_spec:
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
)

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.6.0" # default to torch 2.6
torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,16 +64,10 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.3"]
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: Optional[int] = field(
default=None,
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: Optional[str] = field(
default=None, # nosec B104
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: Optional[int] = field(
default=None,
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(

View File

@@ -1,156 +0,0 @@
"""
CLI tool to delinearize quantized/Linearized Llama-4 models.
"""
import os
from pathlib import Path
from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
def iter_convert_patched_to_hf(model_state_dict, num_experts) -> Generator:
keys = list(model_state_dict.keys())
for key in keys:
if ".feed_forward.experts." not in key:
yield key, model_state_dict[key]
if ".feed_forward.experts.gate_projs" in key:
# gate gets fused with up so skip the yield on this and we'll fuse it when asking for the up
continue
if ".feed_forward.experts.up_projs" in key:
if ".feed_forward.experts.up_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".up_projs.0.")[0]
key = f"{prefix}.gate_up_proj"
# grab all the up_projs and gate_projs across all experts
gate_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.gate_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
up_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.up_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
gate_up_proj = torch.cat((gate_stacked, up_stacked), dim=-1)
del gate_stacked, up_stacked
yield key, gate_up_proj
else:
del model_state_dict[key]
continue
if ".feed_forward.experts.down_projs" in key:
if ".feed_forward.experts.down_projs.0." in key:
# handle the re-shape and fusing of gate and up, and conversion from linear to parameter
prefix = key.split(".down_projs.0.")[0]
key = f"{prefix}.down_proj"
# grab all the down_projs across all experts
down_stacked = torch.stack(
[
model_state_dict[
f"{prefix}.down_projs.{expert_idx}.weight"
].transpose(0, 1)
for expert_idx in range(num_experts)
]
)
yield key, down_stacked
else:
del model_state_dict[key]
continue
def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
"""
Convert a patched HF format Llama4 model (with separated projections)
back to the original HF format (with fused projections).
Args:
model: Path to the patched HF model
output: Path to save the converted model
"""
print(f"Loading model from {model}")
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
unpatch_llama4 = patch_llama4_linearized_modeling()
from transformers import Llama4ForConditionalGeneration
model_ = Llama4ForConditionalGeneration.from_pretrained(
model, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(model)
processor.save_pretrained(output)
device = model_.device.type
if device == "cuda":
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_config = model_.config
config = model_.config.get_text_config()
# Get key dimensions from the config
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
num_experts = config.num_local_experts
print(
f"Model dimensions: hidden_size={hidden_size}, intermediate_size={intermediate_size}, num_experts={num_experts}"
)
# Create output directory if it doesn't exist
os.makedirs(output, exist_ok=True)
# Get state dict
state_dict = model_.state_dict()
del model_
# Create a new state dict for the converted model
converted_state_dict = {}
# First, copy all keys that don't need modification
for key, value in iter_convert_patched_to_hf(state_dict, num_experts):
converted_state_dict[key] = value
del state_dict
if device == "cuda":
torch.cuda.empty_cache()
print("State dict converted.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
# Ideally re-load the model import to load the converted state dict
# Save the converted model
with init_empty_weights():
unpatch_llama4()
model_ = Llama4ForConditionalGeneration(model_config)
if device == "cuda":
print("State dict loaded into model.")
print(
f"peak memory allocated: {torch.cuda.max_memory_allocated() / 1024**2} MB"
)
print(f"peak memory reserved: {torch.cuda.max_memory_reserved() / 1024**2} MB")
model_.load_state_dict(converted_state_dict, strict=False, assign=True)
print(f"Saving converted model to {output}...")
model_.save_pretrained(output)
print(f"Model successfully converted and saved to {output}")
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -330,15 +330,6 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
cli.add_command(lm_eval)

View File

@@ -40,7 +40,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.warning("Error raised: %s", e)
model.generation_config.do_sample = True
model.config.use_cache = True
if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")

View File

@@ -14,7 +14,6 @@ from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@@ -126,7 +125,7 @@ def load_preference_datasets(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:

View File

@@ -84,7 +84,7 @@ from axolotl.utils.collators import (
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
try:
import torch._dynamo # pylint: disable=ungrouped-imports
@@ -538,6 +538,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
@@ -774,7 +776,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -930,6 +931,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,
@@ -1009,8 +1012,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.seed:
training_args_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
@@ -1037,24 +1038,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
@@ -1062,13 +1057,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl is RLType.ORPO:
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.KTO:
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1082,14 +1077,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl is RLType.GRPO:
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl is RLType.IPO:
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
@@ -1126,33 +1121,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
trainer_kwargs = {}
if self.cfg.rl is RLType.IPO:
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
if self.cfg.dpo_label_smoothing:
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl is RLType.GRPO:
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl is RLType.ORPO:
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.KTO:
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl is RLType.SIMPO:
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
@@ -1160,33 +1155,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
else:
trainer_kwargs["processing_class"] = self.tokenizer
dpo_trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
**dpo_trainer_kwargs,
)
if self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
return trainer
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):

View File

@@ -371,15 +371,13 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
loss = super().compute_loss(
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -3,7 +3,6 @@ DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
from axolotl.utils.schemas.enums import RLType
class DPOStrategy:
@@ -24,7 +23,7 @@ class DPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl is RLType.IPO:
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:

View File

@@ -11,4 +11,6 @@ from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -1,124 +0,0 @@
"""
Repeat random sampler (akin to the one implemented in
https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
sequence parallelism functionality; i.e., duplicating data across ranks in the same
sequencee parallel group.
"""
from typing import Sized
import torch
from torch.utils.data import Sampler
class SequenceParallelRepeatRandomSampler(Sampler):
"""
Sampler for GRPO training with sequence parallelism that ensures:
1. Ranks in the same sequence parallel group receive identical data
2. Each index is repeated multiple times for sampling different completions
3. Entire batches are repeated for reuse in multiple updates
"""
def __init__(
self,
dataset: Sized,
mini_repeat_count: int,
world_size: int,
rank: int,
batch_size: int = 1,
repeat_count: int = 1,
sequence_parallel_degree: int = 1,
shuffle: bool = True,
seed: int = 0,
drop_last: bool = False,
):
self.dataset = dataset
self.mini_repeat_count = mini_repeat_count
self.batch_size = batch_size
self.repeat_count = repeat_count
self.shuffle = shuffle
self.seed = seed
self.drop_last = drop_last
self.epoch = 0
self.world_size = world_size
self.rank = rank
# Sequence parallelism parameters
self.sequence_parallel_degree = sequence_parallel_degree
self.num_sp_groups = world_size // sequence_parallel_degree
self.sp_group_id = rank // sequence_parallel_degree
# Adjust dataset size for distributed sampling
self.num_samples = len(self.dataset)
self.total_size = self.num_samples
# Calculate effective number of samples per SP group
if (
self.drop_last
and self.total_size % (self.num_sp_groups * self.batch_size) != 0
):
# Drop last incomplete batch if drop_last is True
self.num_samples_per_sp_group = (
self.total_size // self.batch_size // self.num_sp_groups
) * self.batch_size
else:
# Round up to include last batch if drop_last is False
self.num_samples_per_sp_group = (
(self.total_size + self.batch_size * self.num_sp_groups - 1)
// (self.batch_size * self.num_sp_groups)
* self.batch_size
)
def __iter__(self):
# Deterministically shuffle based on epoch and seed
if self.shuffle:
# Use same seed for all ranks in the same SP group
g = torch.Generator()
seed_value = self.seed + self.epoch + self.sp_group_id * 10000
g.manual_seed(seed_value)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible by batch_size
if len(indices) % self.batch_size != 0:
padding = indices[: self.batch_size - len(indices) % self.batch_size]
indices += padding
# Subsample based on SP group ID
# Each SP group gets distinct batches of data
batch_indices = []
for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
start_idx = i + self.sp_group_id * self.batch_size
end_idx = min(start_idx + self.batch_size, len(indices))
if start_idx < len(indices):
for j in range(self.batch_size):
if start_idx + j < end_idx:
batch_indices.append(indices[start_idx + j])
# Make sure batch_indices is exactly batch_size * num_batches_per_sp_group
if self.drop_last:
num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size
target_len = self.batch_size * num_batches_per_sp_group
if len(batch_indices) > target_len:
batch_indices = batch_indices[:target_len]
# Apply the GRPO repeat pattern
final_indices = []
for _ in range(self.repeat_count):
for idx in batch_indices:
for _ in range(self.mini_repeat_count):
final_indices.append(idx)
return iter(final_indices)
def __len__(self):
# Total length including all repetitions
return (
self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count
)
def set_epoch(self, epoch):
"""Sets the epoch for this sampler"""
self.epoch = epoch

View File

@@ -1,279 +1,26 @@
"""Axolotl GRPO trainer"""
"""
Axolotl GRPO trainer
"""
# pylint: disable=too-many-lines,duplicate-code
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
import torch
import torch.distributed as dist
from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
)
from datasets import Dataset, IterableDataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
Sampler,
)
from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
is_wandb_available,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import (
is_deepspeed_available,
is_rich_available,
)
from trl.models import (
unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc
from trl.trainer.utils import (
pad,
print_prompt_completions_sample,
selective_log_softmax,
)
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
if is_wandb_available():
import wandb
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(
self,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: (
Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None
) = None,
processing_class: PreTrainedTokenizerBase | None = None,
reward_processing_classes: (
PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None
) = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[
torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None
] = (None, None),
peft_config: "PeftConfig | None" = None,
):
# First call the superclass constructor with all arguments
super().__init__(
model=model,
reward_funcs=reward_funcs,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
reward_processing_classes=reward_processing_classes,
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)
# Now execute your custom logic
# Get number of SP groups (number of processes divided by SP degree)
num_processes = self.accelerator.num_processes
num_sp_groups = num_processes // self.args.sequence_parallel_degree
# Calculate batch size per SP group (not per process)
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
possible_values = [
n_gen
for n_gen in range(2, sp_group_batch_size + 1)
if (sp_group_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The batch size per SP group ({num_sp_groups} x "
f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
f"the number of generations per prompt ({self.num_generations}). Given "
"the current configuration, the valid values for the number of "
f"generations are: {possible_values}."
)
if self.args.eval_strategy != "no":
# If sequence parallelism is enabled, calculate batch size per SP group
sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr]
possible_values = [
n_gen
for n_gen in range(2, sp_group_eval_batch_size + 1)
if (sp_group_eval_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
f"must be evenly divisible by the number of generations per prompt "
f"({self.num_generations}). Given the current eval batch size, "
f"the valid values for the number of generations are: {possible_values}."
)
# Initialize the SP group
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
print("end of trainer init")
def _get_train_sampler(self) -> Sampler:
# Get distributed training info
world_size = dist.get_world_size()
rank = dist.get_rank()
effective_batch_size = (
self.args.per_device_train_batch_size
* world_size
* self.args.gradient_accumulation_steps
)
return SequenceParallelRepeatRandomSampler(
dataset=self.train_dataset,
mini_repeat_count=self.num_generations,
world_size=world_size,
rank=rank,
batch_size=effective_batch_size
// self.num_generations
// self.args.sequence_parallel_degree,
repeat_count=self.num_iterations,
sequence_parallel_degree=self.args.sequence_parallel_degree,
shuffle=True,
seed=self.args.seed,
drop_last=True,
)
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Initialize SP group attributes if sequence parallelism is enabled
if self.args.sequence_parallel_degree > 1:
self.sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=self.sp_group)
self.local_world_size = dist.get_world_size(group=self.sp_group)
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
# Add debug print before any modifications
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
)
# Get sampler and create dataloader
sampler = self._get_train_sampler()
dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False)
return dataloader
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
@@ -320,577 +67,3 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
# def _generate_and_score_completions(
# self, inputs: list[dict[str, torch.Tensor | Any]]
# ) -> dict[str, torch.Tensor | Any]:
# device = self.accelerator.device
# prompts = [x["prompt"] for x in inputs]
# prompts_text = [
# maybe_apply_chat_template(example, self.processing_class)["prompt"]
# for example in inputs
# ]
# prompt_inputs = self.processing_class(
# text=prompts_text,
# return_tensors="pt",
# padding=True,
# padding_side="left",
# add_special_tokens=False,
# )
# # pylint: disable=protected-access
# prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
# prompt_ids, prompt_mask = (
# prompt_inputs["input_ids"],
# prompt_inputs["attention_mask"],
# )
# if self.max_prompt_length is not None:
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# # Generate completions using either vLLM or regular generation
# if self.args.use_vllm:
# # First, have main process load weights if needed
# # pylint: disable=access-member-before-definition
# if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
# self._move_model_to_vllm()
# # pylint: disable=attribute-defined-outside-init
# self._last_loaded_step = self.state.global_step
# all_prompts_text = gather_object(prompts_text)
# if self.accelerator.is_main_process:
# # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# # num_generations outputs for each one. This is faster than generating outputs for each duplicate
# # prompt individually.
# # ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
# ordered_set_of_prompts = all_prompts_text[
# :: self.num_generations * self.args.sequence_parallel_degree
# ]
# with profiling_context(self, "vLLM.generate"):
# completion_ids = self.vllm_client.generate(
# prompts=ordered_set_of_prompts,
# n=self.num_generations,
# repetition_penalty=self.repetition_penalty,
# temperature=self.temperature,
# top_p=self.top_p,
# top_k=-1 if self.top_k is None else self.top_k,
# min_p=0.0 if self.min_p is None else self.min_p,
# max_tokens=self.max_completion_length,
# guided_decoding_regex=self.guided_decoding_regex,
# )
# else:
# completion_ids = [None] * (
# len(all_prompts_text) // self.args.sequence_parallel_degree
# )
# # Broadcast the completions from the main process to all processes
# completion_ids = broadcast_object_list(completion_ids, from_process=0)
# # Determine the appropriate slice based on sequence parallelism
# if self.args.sequence_parallel_degree > 1:
# # Calculate SP group ID (which group of ranks this rank belongs to)
# sp_group_id = self.accelerator.process_index // self.local_world_size
# # Calculate the start index for this SP group
# sp_group_start = sp_group_id * len(prompts) * self.local_world_size
# # All ranks in the same SP group get the same data slice
# process_slice = slice(
# sp_group_start,
# sp_group_start + len(prompts),
# )
# completion_ids = completion_ids[process_slice]
# else:
# # Original behavior for non-sequence parallel case
# process_slice = slice(
# self.accelerator.process_index * len(prompts),
# (self.accelerator.process_index + 1) * len(prompts),
# )
# completion_ids = completion_ids[process_slice]
# # Pad the completions, and concatenate them with the prompts
# completion_ids = [
# torch.tensor(ids, device=device) for ids in completion_ids
# ]
# completion_ids = pad(
# completion_ids, padding_value=self.processing_class.pad_token_id
# )
# else:
# # Regular generation path
# with unwrap_model_for_generation(
# self.model_wrapped,
# self.accelerator,
# gather_deepspeed3_params=self.args.ds3_gather_for_generation,
# ) as unwrapped_model:
# prompt_completion_ids = unwrapped_model.generate(
# prompt_ids,
# attention_mask=prompt_mask,
# generation_config=self.generation_config,
# )
# # Compute prompt length and extract completion ids
# prompt_length = prompt_ids.size(1)
# prompt_ids = prompt_completion_ids[:, :prompt_length]
# completion_ids = prompt_completion_ids[:, prompt_length:]
# prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# # Mask everything after the first EOS token
# is_eos = completion_ids == self.processing_class.eos_token_id
# eos_idx = torch.full(
# (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
# )
# eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
# sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
# is_eos.size(0), -1
# )
# completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# # Concatenate prompt_mask with completion_mask for logit computation
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
# logits_to_keep = completion_ids.size(
# 1
# ) # we only need to compute the logits for the completion tokens
# with torch.no_grad():
# # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
# # computation here, and use per_token_logps.detach() instead.
# if self.num_iterations > 1:
# if self.args.sequence_parallel_degree > 1:
# old_per_token_logps, _ = self._get_per_token_logps_v2(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# old_per_token_logps = super()._get_per_token_logps(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# old_per_token_logps = None
# if self.beta == 0.0:
# ref_per_token_logps = None
# elif self.ref_model is not None:
# if self.args.sequence_parallel_degree > 1:
# ref_per_token_logps, _ = self._get_per_token_logps_v2(
# self.ref_model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# ref_per_token_logps = super()._get_per_token_logps(
# self.ref_model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# with self.accelerator.unwrap_model(self.model).disable_adapter():
# if self.args.sequence_parallel_degree > 1:
# ref_per_token_logps, _ = self._get_per_token_logps_v2(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# else:
# ref_per_token_logps = super()._get_per_token_logps(
# self.model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# )
# # Decode the generated completions
# completions_text = self.processing_class.batch_decode(
# completion_ids, skip_special_tokens=True
# )
# if is_conversational(inputs[0]):
# completions = []
# for prompt, completion in zip(prompts, completions_text):
# bootstrap = (
# prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
# )
# completions.append(
# [{"role": "assistant", "content": bootstrap + completion}]
# )
# else:
# completions = completions_text
# rewards_per_func = torch.zeros(
# len(prompts), len(self.reward_funcs), device=device
# )
# for i, (reward_func, reward_processing_class) in enumerate(
# zip(self.reward_funcs, self.reward_processing_classes)
# ):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# reward_func_name = (
# f"reward {reward_func.config._name_or_path.split('/')[-1]}"
# )
# else:
# # pylint: disable=protected-access
# reward_func_name = reward_func.__name__
# with profiling_context(self, reward_func_name):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# if is_conversational(inputs[0]):
# messages = [
# {"messages": p + c} for p, c in zip(prompts, completions)
# ]
# texts = [
# apply_chat_template(x, reward_processing_class)["text"]
# for x in messages
# ]
# else:
# texts = [p + c for p, c in zip(prompts, completions)]
# reward_inputs = reward_processing_class(
# text=texts,
# return_tensors="pt",
# padding=True,
# padding_side="right",
# add_special_tokens=False,
# )
# # pylint: disable=protected-access
# reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
# with torch.inference_mode():
# rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
# :, 0
# ] # Shape (B*G,)
# else:
# # Repeat all input columns (but "prompt" and "completion") to match the number of generations
# keys = [
# key for key in inputs[0] if key not in ["prompt", "completion"]
# ]
# reward_kwargs = {
# key: [example[key] for example in inputs] for key in keys
# }
# output_reward_func = reward_func(
# prompts=prompts, completions=completions, **reward_kwargs
# )
# # Convert None values to NaN
# output_reward_func = [
# reward if reward is not None else torch.nan
# for reward in output_reward_func
# ]
# rewards_per_func[:, i] = torch.tensor(
# output_reward_func, dtype=torch.float32, device=device
# )
# # If all reward functions return None for a given row, issue a detailed warning
# if torch.isnan(rewards_per_func).all(dim=1).any():
# nan_row_idx = (
# torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
# )
# row_reward_kwargs = {
# key: value[nan_row_idx] for key, value in reward_kwargs.items()
# }
# row_reward_kwargs["prompt"] = prompts[nan_row_idx]
# row_reward_kwargs["completion"] = completions[nan_row_idx]
# warnings.warn(
# f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
# "Please ensure that at least one reward function returns a valid reward."
# )
# # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# # completions may be distributed across processes
# rewards_per_func = gather(rewards_per_func)
# # Apply weights to each reward function's output and sum
# rewards = (
# rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
# ).nansum(dim=1)
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
# self.num_generations, dim=0
# )
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(
# self.num_generations, dim=0
# )
# advantages = rewards - mean_grouped_rewards
# if self.args.scale_rewards:
# advantages = advantages / (std_grouped_rewards + 1e-4)
# # Slice to keep only the local part of the data
# process_slice = slice(
# self.accelerator.process_index * len(prompts),
# (self.accelerator.process_index + 1) * len(prompts),
# )
# advantages = advantages[process_slice]
# # Log the metrics
# mode = "eval" if self.control.should_evaluate else "train"
# if mode == "train":
# # pylint: disable=no-member
# self._total_train_tokens += (
# self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
# )
# # pylint: disable=no-member
# self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
# completion_length = (
# self.accelerator.gather_for_metrics(completion_mask.sum(1))
# .float()
# .mean()
# .item()
# )
# self._metrics[mode]["completion_length"].append(completion_length)
# # Calculate mean reward per function, but only for samples where the function was applied
# for i, reward_func in enumerate(self.reward_funcs):
# if isinstance(
# reward_func, nn.Module
# ): # Module instead of PretrainedModel for compat with compiled models
# reward_func_name = reward_func.config._name_or_path.split("/")[-1]
# else:
# # pylint: disable=protected-access
# reward_func_name = reward_func.__name__
# # Only calculate mean for samples where this reward function was applied (non-NaN values)
# mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
# self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
# self._metrics[mode]["reward"].append(rewards.mean().item())
# self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
# if (
# self.log_completions
# and self.state.global_step % self.args.logging_steps == 0
# ):
# prompts_to_log = gather_object(prompts_text)
# completions_to_log = gather_object(completions_text)
# rewards_to_log = rewards.tolist()
# if self.accelerator.is_main_process:
# if is_rich_available():
# print_prompt_completions_sample(
# prompts_to_log,
# completions_to_log,
# rewards_to_log,
# self.state.global_step,
# )
# if (
# self.args.report_to
# and "wandb" in self.args.report_to
# and wandb.run is not None
# ):
# import pandas as pd
# # For logging
# table = {
# "step": [str(self.state.global_step)] * len(rewards),
# "prompt": prompts_to_log,
# "completion": completions_to_log,
# "reward": rewards.tolist(),
# }
# df = pd.DataFrame(table)
# wandb.log({"completions": wandb.Table(dataframe=df)})
# return {
# "prompt_ids": prompt_ids,
# "prompt_mask": prompt_mask,
# "completion_ids": completion_ids,
# "completion_mask": completion_mask,
# "old_per_token_logps": old_per_token_logps,
# "ref_per_token_logps": ref_per_token_logps,
# "advantages": advantages,
# }
# def _get_per_token_logps_v2(
# self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None
# ):
# # Pad sequence to be divisible by SP degree if needed
# total_seq_len = input_ids.shape[1]
# if total_seq_len % self.local_world_size != 0:
# pad_len = self.local_world_size - (total_seq_len % self.local_world_size)
# pad_token_id = self.processing_class.pad_token_id or 0
# # Pad input_ids and attention_mask
# padding = torch.full(
# (input_ids.shape[0], pad_len),
# pad_token_id,
# dtype=input_ids.dtype,
# device=input_ids.device,
# )
# input_ids = torch.cat([input_ids, padding], dim=1)
# attn_padding = torch.zeros(
# (attention_mask.shape[0], pad_len),
# dtype=attention_mask.dtype,
# device=attention_mask.device,
# )
# attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
# if completion_mask is not None:
# completion_mask = torch.cat([completion_mask, attn_padding], dim=1)
# total_seq_len += pad_len
# logits_to_keep += pad_len
# # Split the sequence
# slice_size = total_seq_len // self.local_world_size
# start = self.local_rank * slice_size
# end = start + slice_size
# # Get our slice
# input_ids_slice = input_ids[:, start:end]
# attention_mask_slice = attention_mask[:, start:end]
# # Calculate where our slice starts and ends relative to the completion tokens
# local_completion_mask = None
# prompt_len = input_ids.size(1) - logits_to_keep
# if start >= prompt_len:
# # Slice starts within the completion section
# start_in_completion = start - prompt_len
# end_in_completion = min(end - prompt_len, logits_to_keep)
# local_logits_to_keep = end_in_completion - start_in_completion
# if completion_mask is not None:
# local_completion_mask = completion_mask[
# :, start_in_completion:end_in_completion
# ]
# elif end <= prompt_len:
# # Slice is entirely within the prompt section (no completion tokens)
# local_logits_to_keep = 0
# if completion_mask is not None:
# local_completion_mask = torch.zeros(
# (completion_mask.size(0), 0), device=completion_mask.device
# )
# else:
# # Slice contains the boundary between prompt and completion
# start_in_completion = 0
# end_in_completion = min(end - prompt_len, logits_to_keep)
# local_logits_to_keep = end_in_completion - start_in_completion
# if completion_mask is not None:
# local_completion_mask = completion_mask[
# :, start_in_completion:end_in_completion
# ]
# # Get logits with enough context to compute log probs
# logits = model(
# input_ids=input_ids_slice,
# attention_mask=attention_mask_slice,
# logits_to_keep=local_logits_to_keep + 1,
# ).logits
# # Only the last rank that contains completion tokens needs to remove the last logit
# is_last_rank_with_completions = (
# self.local_rank == self.local_world_size - 1 # Last rank overall
# or end
# >= prompt_len
# + logits_to_keep # Our slice includes the last completion token
# )
# if is_last_rank_with_completions:
# logits = logits[:, :-1]
# if local_completion_mask is not None:
# local_completion_mask = local_completion_mask[:, :-1]
# local_logits_to_keep -= 1
# if start >= prompt_len:
# # For ranks where slice is all completion tokens,
# # we need to offset to match the logits (which predict the next token)
# offset = 1 # Skip the first token as it's predicted by the last token of the previous rank
# local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
# else:
# # For the rank that contains the prompt-completion boundary,
# # we need to take completion tokens only
# offset = prompt_len - start # Where completions start in our slice
# local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
# logits = logits[
# :, -local_logits_to_keep:
# ] # Take only logits for completion tokens
# logits = logits / self.temperature
# per_token_logps = selective_log_softmax(logits, local_input_ids)
# return per_token_logps, local_completion_mask
# # pylint: disable=unused-argument
# @profiling_decorator
# def compute_loss(
# self, model, inputs, return_outputs=False, num_items_in_batch=None
# ):
# if return_outputs:
# raise ValueError("The GRPOTrainer does not support returning outputs")
# # Unpack inputs
# prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
# completion_ids, completion_mask = (
# inputs["completion_ids"],
# inputs["completion_mask"],
# )
# prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# logits_to_keep = completion_ids.size(1)
# if self.args.sequence_parallel_degree > 1:
# per_token_logps, completion_mask = self._get_per_token_logps_v2(
# model,
# prompt_completion_ids,
# attention_mask,
# logits_to_keep,
# completion_mask,
# )
# else:
# per_token_logps = super()._get_per_token_logps(
# model, prompt_completion_ids, attention_mask, logits_to_keep
# )
# # Compute the KL divergence between the model and the reference model
# if self.beta != 0.0:
# ref_per_token_logps = inputs["ref_per_token_logps"]
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps)
# - (ref_per_token_logps - per_token_logps)
# - 1
# )
# # Compute the loss
# advantages = inputs["advantages"]
# # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
# # and use per_token_logps.detach() instead.
# old_per_token_logps = (
# inputs["old_per_token_logps"]
# if self.num_iterations > 1
# else per_token_logps.detach()
# )
# coef_1 = torch.exp(per_token_logps - old_per_token_logps)
# coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
# per_token_loss1 = coef_1 * advantages.unsqueeze(1)
# per_token_loss2 = coef_2 * advantages.unsqueeze(1)
# per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
# if self.beta != 0.0:
# per_token_loss = per_token_loss + self.beta * per_token_kl
# loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# # Log metrics
# mode = "eval" if self.control.should_evaluate else "train"
# if self.beta != 0.0:
# mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
# self._metrics[mode]["kl"].append(
# self.accelerator.gather_for_metrics(mean_kl).mean().item()
# )
# is_clipped = (per_token_loss1 < per_token_loss2).float()
# clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
# self._metrics[mode]["clip_ratio"].append(
# self.accelerator.gather_for_metrics(clip_ratio).mean().item()
# )
# return loss

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,144 +1,16 @@
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
"""Module for Axolotl trainer sequence parallelism mixin"""
import functools
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
def _handle_logits_to_keep(
logits_to_keep,
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
total_seq_len: int,
):
"""
Handle logits_to_keep parameter for sequence parallelism.
Args:
logits_to_keep: Integer or tensor indicating which positions to compute logits
for.
local_rank: Rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: Ring attention function being used.
total_seq_len: Full sequence length.
Returns:
Adjusted logits_to_keep appropriate for this rank's sharded sequence
"""
print("start of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep)
# No transformation needed if logits_to_keep is None
if logits_to_keep is None:
return None
assert isinstance(
logits_to_keep, int
), "sequence parallelism currently only supports integer logits_to_keep"
assert ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
# For standard sharding, each rank gets a contiguous chunk
chunk_size = total_seq_len // local_world_size
start_idx = local_rank * chunk_size
end_idx = start_idx + chunk_size
# Check if logits_to_keep is in this rank's range
if start_idx <= logits_to_keep < end_idx:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), logits_to_keep - start_idx)
return logits_to_keep - start_idx
else:
print("end of _handle_logits_to_keep")
print(dist.get_rank(), -1)
return -1
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
ring_attn_func: The ring attention function to use.
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
if key == "logits_to_keep":
batch[key] = _handle_logits_to_keep(
logits_to_keep=batch[key],
local_rank=local_rank,
local_world_size=local_world_size,
ring_attn_func=ring_attn_func,
total_seq_len=total_seq_len,
)
return batch
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
@@ -215,160 +87,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
print("start of sequence_parallel_post_hook")
# Gather the sharded outputs
output = self.gather_outputs(output)
print("end of sequence_parallel_post_hook")
return output
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result

View File

@@ -9,8 +9,6 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:
@@ -220,12 +218,6 @@ class AxolotlTrainingMixins:
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
},
)
# multi-modal section

View File

@@ -12,14 +12,12 @@ See https://github.com/apple/ml-cross-entropy
Run the following command to install `cut_cross_entropy[transformers]` if you don't have it already.
- If you are in dev environment
```bash
# if you are in dev environment
python scripts/cutcrossentropy_install.py | sh
```
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
# if you are not in dev environment
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
```
## Usage
@@ -47,8 +45,6 @@ cut_cross_entropy: true
- qwen2
- cohere
- cohere2
- glm
- glm4
## Citation

View File

@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
)

View File

@@ -1,57 +0,0 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -165,7 +165,7 @@ def cce_forward(
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None, # type: ignore
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -254,7 +254,7 @@ def cce_forward_multimodal(
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
@@ -263,13 +263,13 @@ def cce_forward_multimodal(
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
original_inputs_embeds_shape = inputs_embeds.shape
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
final_mask_1d = final_mask[..., 0].reshape(-1)

View File

@@ -20,10 +20,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -49,8 +45,6 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,

View File

@@ -49,7 +49,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
)
sharded_sd[param_name] = sharded_tensor
model.load_state_dict(sharded_sd, assign=True)
model.load_state_dict(sharded_sd)
def patch_accelerate_fsdp_utils():

View File

@@ -7,11 +7,12 @@ import torch
import transformers
def patch_flex_wrapper(**flex_attn_compile_kwargs):
def patch_flex_wrapper():
# TODO remove this patch when transformers#37285 is merged and in a release
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
if not is_torch_2_6:
if not (is_torch_2_6 and is_transformers_below_4_51):
return
from torch.nn.attention.flex_attention import flex_attention
@@ -31,24 +32,17 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def del_singleton(cls):
cls._instance = None
@torch.compiler.disable(recursive=False)
def __init__(self, training):
def __init__(self):
"""
Initialize or update the singleton instance.
"""
self.training = None
if not self._is_flex_compiled or training != self.training:
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training
if not self._is_flex_compiled:
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
dynamic=False,
mode="max-autotune-no-cudagraphs",
fullgraph=True,
)
self._is_flex_compiled = True
@@ -56,22 +50,15 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
return self._compiled_flex_attention
transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
setattr(
sys.modules["transformers.integrations.flex_attention"],
"WrappedFlexAttention",
WrappedFlexAttention,
)
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
is_transformers_eq_4_51 = transformers.__version__ == "4.51.0"
if not is_torch_2_6:
if not (is_torch_2_6 and is_transformers_eq_4_51):
return
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
)
from torch.nn.attention.flex_attention import (
BlockMask,
)
@@ -117,16 +104,14 @@ def patch_flex_make_mask():
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d,
value=0,
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
attention_mask_2d, value=0, pad=(0, key_length)
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
@@ -153,18 +138,6 @@ def patch_flex_make_mask():
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
mask_mod_maybe_combined = (
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
)
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
@@ -172,10 +145,10 @@ def patch_flex_make_mask():
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
@@ -187,16 +160,11 @@ def patch_flex_make_mask():
)
for n in tuple(sys.modules):
if ".modeling_" in n:
if ".modeling_" in n and "llama4" not in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[n].make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
setattr(
sys.modules[n],
"make_flex_block_causal_mask",
patched_make_flex_block_causal_mask,
)
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask

View File

@@ -12,12 +12,10 @@ from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
configure_logging()
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
@@ -42,11 +40,7 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
@@ -54,9 +48,6 @@ def register_ring_attn(
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
`batch` function.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
@@ -67,9 +58,7 @@ def register_ring_attn(
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
rank = dist.get_rank()
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
@@ -79,8 +68,10 @@ def register_ring_attn(
f"must evenly divide world_size ({world_size})"
)
# Assign ranks to sequence parallel groups
# Detailed logging of group formation
rank = dist.get_rank()
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
@@ -101,37 +92,35 @@ def register_ring_attn(
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
if heads_k_stride is None:
heads_k_stride = 1
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func in [
RingAttnFunc.BATCH_RING,
RingAttnFunc.BATCH_ZIGZAG,
RingAttnFunc.BATCH_STRIPE,
]:
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(),
ring_attn_func=ring_attn_func,
)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
def update_ring_attn_params(position_ids: torch.Tensor | None):
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.
Args:
position_ids: Optional tensor of position IDs (for sample packed data).
batch: A dictionary with a batch of data. May or may not contain `position_ids`
data; if not, we compute it.
"""
from ring_flash_attn import update_ring_flash_attn_params
input_ids = batch["input_ids"]
position_ids = batch.get("position_ids")
if position_ids is None:
seq_len = input_ids.shape[1]
position_ids = torch.arange(
0, seq_len, dtype=torch.long, device=input_ids.device
).unsqueeze(0)
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -1,11 +0,0 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -1,192 +0,0 @@
"""
HuggingFace flash attention adapter for basic ring attention (batch API).
Inspired by
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
Our implementation closely follows the structure of that module, but we've minified it
somewhat to support only the latest versions of transformers.
"""
# pylint: disable=protected-access,cyclic-import
import os
from typing import Callable
import torch
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from axolotl.utils.schemas.enums import RingAttnFunc
RING_ATTN_FUNC_MAPPING = {
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
}
def create_flash_attn_forward(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
"""
Create a ring flash attention forward function compatible with HuggingFace's
interface.
Args:
process_group: A PyTorch distributed process group.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
Returns:
A function that implements the ring flash attention forward pass with the
signature expected by HuggingFace Transformers.
"""
# transformers 4.48+
# pylint: disable=unused-argument
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
position_ids: torch.Tensor | None = None,
softmax_scale: float | None = None,
sliding_window: int | None = None,
use_top_left_mask: bool = False,
softcap: float | None = None,
deterministic: bool = None,
cu_seq_lens_q: torch.LongTensor | None = None,
cu_seq_lens_k: torch.LongTensor | None = None,
max_length_q: int | None = None,
max_length_k: int | None = None,
target_dtype: torch.dtype | None = None,
**kwargs,
):
"""
Calls the forward method of Ring Flash Attention.
Args:
query_states: Tensor containing the query vectors.
key_states: Tensor containing the key vectors.
value_states: Tensor containing the value vectors.
attention_mask: Not used in this implementation.
query_length: Integer representing the length of the query sequence.
is_causal: Boolean indicating whether to apply a causal mask to the attention.
dropout: Float representing the dropout probability. Default is 0.0.
position_ids: Not used in this implementation.
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
sliding_window: Optional integer defining the size of the sliding attention window.
Default is None.
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
Default is False.
softcap: Not used in this implementation.
deterministic: Optional boolean to enforce deterministic computation. Default is None.
cu_seq_lens_q: Not used in this implementation.
cu_seq_lens_k: Not used in this implementation.
max_length_q: Not used in this implementation.
max_length_k: Not used in this implementation.
target_dtype: Not used in this implementation.
**kwargs: Additional keyword arguments. Not used in this implementation.
Returns:
torch.Tensor: The output of the attention mechanism, with shape
`[batch_size, query_length, num_heads, head_dim]`.
"""
if not use_top_left_mask:
causal = is_causal
else:
causal = is_causal and query_length != 1
# Handle sliding window
use_sliding_windows = (
_flash_supports_window_size
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
window_size = (
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
)
# Handle deterministic mode
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
deterministic = (
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
)
# Call ring flash attention function
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
query_states,
key_states,
value_states,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
group=process_group,
)
return attn_output
return _flash_attention_forward
def substitute_hf_flash_attn(
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
):
"""
Substitute HuggingFace's flash attention implementation with ring-based implementation.
Args:
process_group: PyTorch distributed process group for communication.
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
attention with.
"""
try:
# Substitute flash attention
old_flash_attention_forward = (
transformers.modeling_flash_attention_utils._flash_attention_forward
)
new_flash_attention_forward = create_flash_attn_forward(
process_group=process_group, ring_attn_func=ring_attn_func
)
if check_params(old_flash_attention_forward, new_flash_attention_forward):
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
else:
raise ValueError(
"The signature of the new flash attention forward function does not match the old one."
)
except Exception as exception:
raise ValueError(
f"The current transformer version {transformers.__version__} is not supported. "
"Please use pip install -U transformers to upgrade to the latest version. "
"If the code failed with the latest version, "
f"please file an issue."
) from exception
# Register with ALL_ATTENTION_FUNCTIONS if available
if ALL_ATTENTION_FUNCTIONS is not None:
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward

View File

@@ -93,20 +93,9 @@ def patch_llama4_linearized_modeling():
"""
from transformers.models.llama4 import modeling_llama4
old_lamma_4_text_experts = modeling_llama4.Llama4TextExperts
modeling_llama4.Llama4TextExperts = Llama4TextExperts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
Llama4TextExperts,
)
def unpatch():
modeling_llama4.Llama4TextExperts = old_lamma_4_text_experts
setattr(
sys.modules["transformers.models.llama4"],
"Llama4TextExperts",
old_lamma_4_text_experts,
)
return unpatch

View File

@@ -31,8 +31,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2",
"deepseek_v2",
"deepseek_v3",
"glm",
"glm4",
]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch)
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -1,78 +0,0 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
training_loop = inspect.getsource(Trainer.evaluation_loop)
return training_loop
def check_evaluation_loop_is_patchable() -> bool:
eval_loop = get_evaluation_loop_code()
eval_loop, _ = detab_code(eval_loop)
return ORIGINAL_TRAINER_CODE in eval_loop
def patch_evaluation_loop_for_fsdp2():
"""
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
"""
try:
evaluation_loop = get_evaluation_loop_code()
except OSError:
return
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
evaluation_loop
)
evaluation_loop, _ = detab_code(evaluation_loop)
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
return
evaluation_loop = evaluation_loop.replace(
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
)
evaluation_loop = evaluation_loop.replace(
"def evaluation_loop(",
"def _fixed_evaluation_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in evaluation_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer.evaluation_loop = ( # pylint: disable=protected-access
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -6,7 +6,6 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict
@@ -26,15 +25,11 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
try:
@@ -86,11 +81,6 @@ def setup_model_and_tokenizer(
# Apply freezing if specified
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
if any(
any(embed in param for embed in ["lm_head", "embed_tokens"])
for param in cfg.unfrozen_parameters
):
model.enable_input_require_grads()
return model, tokenizer, peft_config, processor
@@ -109,7 +99,7 @@ def setup_reference_model(
Reference model if needed for RL training, `None` otherwise.
"""
model_ref = None
if cfg.rl and cfg.rl != RLType.ORPO:
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
@@ -190,28 +180,16 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
LOG.info("Starting trainer...")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -1,12 +1,19 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
"""
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
@dataclass
class DataCollatorForSeq2Seq:
@@ -41,16 +48,28 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
model: Any | None = None
padding: bool | str | PaddingStrategy = True
max_length: int | None = None
pad_to_multiple_of: int | None = None
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
@@ -120,8 +139,40 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].shape[1]
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Update params for ring attention calculation
update_ring_attn_params(batch=batch)
# Slice batch for sequence parallel processing
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
batch[key] = batch[key][:, start:end]
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -126,6 +126,9 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -18,9 +18,8 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
LOG = logging.getLogger("axolotl")
def _get_path(ds_hash, cfg):
@@ -81,7 +80,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -101,7 +100,7 @@ def drop_long_rl_seq(
len_prompt + len_rejected
) <= sequence_len
if rl is RLType.KTO:
if rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -115,7 +114,7 @@ def drop_long_rl_seq(
return (len_prompt + len_completion) <= sequence_len
if rl is RLType.GRPO:
if rl == "grpo":
return True
raise ValueError("Unknown RL type")
@@ -138,9 +137,9 @@ def load_prepare_preference_datasets(cfg):
if _type:
if isinstance(_type, DictDefault):
_type = "user_defined.default"
if _cfg.rl is RLType.ORPO:
if _cfg.rl == "orpo":
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
elif _cfg.rl is RLType.KTO:
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
@@ -151,7 +150,7 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):

View File

@@ -3,7 +3,6 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -118,27 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
@@ -351,23 +332,16 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
if isinstance(dataset, IterableDataset):
num_workers = cfg.dataset_processes
def gen_from_iter_ds(_ds, worker_id: List[int], num_workers: List[int]):
"""Generator function to correctly splice the dataset for each worker"""
for i, item in enumerate(_ds):
if i % num_workers[0] == worker_id[0]:
yield item
def gen_from_iter_ds(_ds, _=None):
yield from _ds
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=num_workers,
num_proc=cfg.dataset_processes,
split=split,
gen_kwargs={
"worker_id": list(range(num_workers)),
"num_workers": [num_workers] * num_workers,
},
gen_kwargs={"_": list(range(cfg.dataset_processes))},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:

View File

@@ -2,14 +2,13 @@
module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple, Union
from accelerate.logging import get_logger
from axolotl.utils.distributed import is_main_process
LOG = get_logger(__name__)
LOG = logging.getLogger("axolotl.utils.freeze")
def freeze_layers_except(model, regex_patterns):
@@ -185,7 +184,7 @@ class LayerNamePattern:
"""
self.raw_pattern = pattern
name_pattern, self.range = self._parse_pattern(pattern)
self.name_regex = re.compile(re.sub(r"\.(?!\+)", "\\.", name_pattern))
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
def match(self, name: str) -> bool:
"""

View File

@@ -72,7 +72,6 @@ from axolotl.utils.distributed import (
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
@@ -543,17 +542,6 @@ class ModelLoader:
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
patch_flex_make_mask()
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "llama4":
@@ -656,7 +644,6 @@ class ModelLoader:
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
)
def patch_attention(self) -> None:
@@ -918,6 +905,13 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
patch_flex_wrapper,
)
patch_flex_wrapper()
patch_flex_make_mask()
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
@@ -1121,7 +1115,7 @@ class ModelLoader:
return skip_move_to_device
def adjust_model_config(self) -> None:
def ajust_model_config(self) -> None:
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "max_position_embeddings")
@@ -1281,7 +1275,7 @@ class ModelLoader:
else:
self.model.tie_weights()
self.adjust_model_config()
self.ajust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in (
@@ -1341,7 +1335,7 @@ class ModelLoader:
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if (
self.cfg.adapter
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and self.cfg.rl in ["dpo", "ipo", "kto"]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = max(last_step - 1, 0)
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:

View File

@@ -18,7 +18,6 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
@@ -28,7 +27,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.enums import ChatTemplate, RLType
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
@@ -226,7 +225,6 @@ class AxolotlInputConfig(
sdp_attention: bool | None = None
s2_attention: bool | None = None
flex_attention: bool | None = None
flex_attn_compile_kwargs: dict[str, Any] | None = None
flash_attention: bool | None = None
flash_attn_cross_entropy: bool | None = None
flash_attn_rms_norm: bool | None = None
@@ -260,7 +258,6 @@ class AxolotlInputConfig(
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
ring_attn_func: RingAttnFunc | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -661,7 +658,6 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"
@@ -719,10 +715,9 @@ class AxolotlInputConfig(
and data.get("eval_sample_packing") is None
and not data.get("eval_table_size")
):
if is_main_process():
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True
if (
@@ -784,7 +779,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl is RLType.SIMPO and self.warmup_ratio:
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
@@ -1151,19 +1146,21 @@ class AxolotlInputConfig(
return data
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if self.sample_packing and self.micro_batch_size > 1:
if not info.data["micro_batch_size"] == 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"micro_batch_size must be set to 1 "
"due to a `ring-flash-attn` requirement"
)
@@ -1179,41 +1176,16 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
if is_main_process():
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
raise ValueError(
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = getattr(self, "sample_packing", False)
self.ring_attn_func = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return self
return value
@model_validator(mode="before")
@classmethod
@@ -1304,14 +1276,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_fsdp2 = (
data.get("fsdp_config") is not None
and str(data.get("fsdp_config").get("fsdp_version")) == "2"
)
if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2:
if capabilities and capabilities.get("n_gpu", 0) > 1:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
)
return data

View File

@@ -6,12 +6,12 @@ from enum import Enum
class RLType(str, Enum):
"""RL trainer type configuration subset"""
DPO = "dpo" # pylint: disable=invalid-name
GRPO = "grpo" # pylint: disable=invalid-name
IPO = "ipo" # pylint: disable=invalid-name
ORPO = "orpo" # pylint: disable=invalid-name
KTO = "kto" # pylint: disable=invalid-name
SIMPO = "simpo" # pylint: disable=invalid-name
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -53,14 +53,3 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class RingAttnFunc(str, Enum):
"""Enum class for supported `ring-flash-attn` implementations"""
# VARLEN_RING = "varlen_ring"
# VARLEN_ZIGZAG = "varlen_zigzag"
VARLEN_LLAMA3 = "varlen_llama3"
BATCH_RING = "batch_ring"
BATCH_ZIGZAG = "batch_zigzag"
BATCH_STRIPE = "batch_stripe"

View File

@@ -36,11 +36,3 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -17,7 +17,6 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -236,8 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
if cfg.model_config_type in ["mamba", "gemma3"]:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -348,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing:
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -627,12 +625,6 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
if (
cfg.torch_compile
and cfg.fsdp_config
and str(cfg.fsdp_config.fsdp_version) == "2"
):
patch_evaluation_loop_for_fsdp2()
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref

View File

@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@@ -56,7 +56,6 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -102,7 +101,6 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -131,7 +129,6 @@ class TestCutCrossEntropyIntegration:
attention_type: True,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,7 +54,6 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -101,7 +100,6 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -1,2 +0,0 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,20 +49,18 @@ class TestPackedFlex:
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 2,
"max_steps": 5,
"use_tensorboard": True,
"save_strategy": "no",
}

View File

@@ -177,7 +177,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
cfg.base_model,
@@ -265,7 +264,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
cfg.base_model,

View File

@@ -621,6 +621,12 @@ class TestMultiGPULlama:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
# TODO: remove skip once deepspeed regression is fixed
# see https://github.com/huggingface/transformers/pull/37324
@pytest.mark.skipif(
transformers_version_eq("4.51.0"),
reason="zero3 is not supported with transformers==4.51.0",
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],

View File

@@ -3,14 +3,13 @@
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
from ..utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"
@@ -18,15 +17,8 @@ os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
def _run_sequence_parallel_test(
self,
temp_dir,
sample_packing=True,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
):
"""Helper method to run sequence parallel tests with different configurations"""
def test_sequence_parallel_training(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -35,9 +27,9 @@ class TestSequenceParallelism:
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": sample_packing,
"eval_sample_packing": sample_packing,
"pad_to_sequence_len": pad_to_sequence_len,
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -53,7 +45,7 @@ class TestSequenceParallelism:
],
"num_epochs": 1,
"max_steps": 8,
"micro_batch_size": micro_batch_size,
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -69,7 +61,6 @@ class TestSequenceParallelism:
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -95,35 +86,3 @@ class TestSequenceParallelism:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
self,
temp_dir,
sample_packing,
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
)

View File

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
)
peft_config = get_peft_config(
{
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda:0"
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
)
# Apply LoRA configuration

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,7 +60,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,7 +103,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -67,7 +67,6 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -65,7 +65,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,7 +105,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -70,7 +70,6 @@ class TestLoraLlama(unittest.TestCase):
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -121,7 +120,6 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -1,4 +1,6 @@
"""E2E tests for mixtral"""
"""
E2E tests for mixtral
"""
import logging
import os
@@ -7,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -58,7 +60,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -97,7 +98,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -47,7 +47,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -80,7 +79,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -83,7 +82,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True,
"flash_attention": True,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"load_in_8bit": False,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -115,7 +114,6 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
@@ -46,9 +46,8 @@ class TestResumeLlama:
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 2,
@@ -68,7 +67,6 @@ class TestResumeLlama:
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,22 +2,17 @@
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
@pytest.fixture
@@ -52,27 +47,6 @@ def fixture_cfg():
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@@ -99,6 +73,8 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
@@ -106,11 +82,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
@@ -122,308 +94,88 @@ class TestRingAttention:
set_ring_attn_group(None)
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
# Mock the is_main_process function to return True
monkeypatch.setattr(
"axolotl.utils.schemas.config.is_main_process", lambda: True
)
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
assert torch.all(result["input_ids"] == expected_input_ids)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
MagicMock,
)
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -72,7 +72,6 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -123,7 +122,6 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -179,7 +177,6 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -102,7 +102,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,7 +109,6 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

Some files were not shown because too many files have changed in this diff Show More