Compare commits


13 Commits

Author  SHA1  Message  Date
Sung Ching Liu  f8e92407ff  Update src/axolotl/common/datasets.py (Co-authored-by: Dan Saunders <danjsaund@gmail.com>)  2025-04-17 09:47:14 -04:00
Sung Ching Liu  c12906134d  Update src/axolotl/prompt_strategies/base.py (Co-authored-by: Dan Saunders <danjsaund@gmail.com>)  2025-04-17 09:47:14 -04:00
Sunny Liu  8154d26614  nit  2025-04-17 09:47:14 -04:00
Sunny Liu  fefcbc300d  barebone-ify the test so we get rid of unneeded processes  2025-04-17 09:47:14 -04:00
Sunny Liu  7d479348ee  custom reward function loading, proeprly done  2025-04-17 09:47:14 -04:00
bursteratom  ce0259db13  add outputdir  2025-04-17 09:47:14 -04:00
Sung Ching Liu  2798817cf9  Update tests/e2e/solo/test_grpo.py (Co-authored-by: NanoCode012 <nano@axolotl.ai>)  2025-04-17 09:47:14 -04:00
Sunny Liu  0e1b081e49  add unit test  2025-04-17 09:47:14 -04:00
Sunny Liu  8df37ad91f  propoer import from file_path after all else fails  2025-04-17 09:47:14 -04:00
Sung Ching Liu  9b74298328  Update src/axolotl/prompt_strategies/base.py (Co-authored-by: Wing Lian <wing.lian@gmail.com>)  2025-04-17 09:47:14 -04:00
Sunny Liu  ae8738aa87  skip check_datasets_label during debug for grpo  2025-04-17 09:47:14 -04:00
Sunny Liu  ec52561a0c  import from filepath if can't import_module  2025-04-17 09:47:14 -04:00
Sunny Liu  eadb16c709  test import-wihtin-import relative path  2025-04-17 09:47:14 -04:00
84 changed files with 494 additions and 1788 deletions

View File

@@ -46,18 +46,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""

View File

@@ -24,18 +24,13 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras: vllm
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: vllm axolotl_extras: vllm
is_latest: true is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -98,11 +93,6 @@ jobs:
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -148,7 +138,7 @@ jobs:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.4.1
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -8,7 +8,6 @@ on:
- 'setup.py' - 'setup.py'
- 'pyproject.toml' - 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml' - '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -43,14 +42,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
@@ -75,7 +67,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.multigpu modal run cicd.multigpu

View File

@@ -147,7 +147,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2 max-parallel: 2
matrix: matrix:
python_version: ["3.11"] python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"] pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -109,7 +109,6 @@ jobs:
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v5 uses: codecov/codecov-action@v5
with: with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }} flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false fail_ci_if_error: false
@@ -242,7 +241,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests
@@ -258,12 +256,6 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -275,13 +267,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
num_gpus: 1 num_gpus: 1
axolotl_extras: axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@@ -302,7 +288,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal - name: Run tests job on Modal
run: | run: |
modal run cicd.e2e_tests modal run cicd.e2e_tests

View File

@@ -9,7 +9,8 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \ --ignore=tests/patched/ \
--ignore=tests/cli \ --ignore=tests/cli \
/workspace/axolotl/tests/ \ /workspace/axolotl/tests/ \
--cov=axolotl --cov=axolotl \
--cov-report=xml:coverage.xml
# Run lora kernels tests with coverage append # Run lora kernels tests with coverage append
pytest -v --durations=10 \ pytest -v --durations=10 \
@@ -50,6 +51,11 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \ /workspace/axolotl/tests/e2e/ \
--cov=axolotl \ --cov=axolotl \
--cov-append \ --cov-append \
--cov-report=xml:e2e-coverage.xml --cov-report=xml:coverage.xml
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true # Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -28,7 +28,6 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
} }

View File

@@ -29,7 +29,6 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"), "CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub", "HF_HOME": "/workspace/data/huggingface-cache/hub",
} }

View File

@@ -1,23 +1,25 @@
#!/bin/bash #!/bin/bash
set -e set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \ --cov=axolotl \
--cov-append --cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \ --cov=axolotl \
--cov-append \ --cov-append \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi

View File

@@ -1,7 +1,5 @@
codecov: codecov:
require_ci_to_pass: yes require_ci_to_pass: yes
notify:
wait_for_ci: true
coverage: coverage:
precision: 2 precision: 2
@@ -51,6 +49,3 @@ comment:
require_changes: no require_changes: no
require_base: no require_base: no
require_head: yes require_head: yes
github_checks:
annotations: false

View File

@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,17 +199,6 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details. See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
## Legacy CLI Usage ## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI: While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -49,8 +49,7 @@ sections = [
("Knowledge Distillation (KD)", "kd"), ("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"), ("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"), ("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum"), ("Spectrum", "spectrum")
("LLMCompressor", "llm_compressor")
] ]
for section_name, folder_name in sections: for section_name, folder_name in sections:

View File

@@ -28,8 +28,6 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples: Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0` - `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1` - `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1` - `main-base-py3.11-cu124-2.4.1`
@@ -52,7 +50,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main # on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version} main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4) # latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
main-latest main-latest
# nightly build # nightly build
@@ -70,7 +68,6 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples: Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0` - `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1` - `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1` - `main-py3.11-cu124-2.4.1`

View File

@@ -19,12 +19,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods} ## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi} ### PyPI Installation (Recommended) {#sec-pypi}
```{.bash} ```{.bash}

View File

@@ -1,62 +0,0 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,77 +0,0 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -26,11 +26,3 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @
### Llama 4 Maverick 17Bx128Experts (400B) ### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon Coming Soon
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

View File

@@ -10,6 +10,7 @@ plugins:
liger_glu_activation: true liger_glu_activation: true
liger_rms_norm: true liger_rms_norm: true
liger_layer_norm: true liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true load_in_4bit: true

View File

@@ -1,5 +1,4 @@
codecov codecov
codecov-cli
pytest pytest
pytest-cov pytest-cov
pytest-retry pytest-retry

View File

@@ -6,20 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.8 liger-kernel==0.5.6
# END section # END section
packaging==23.2 packaging==23.2
peft==0.15.2 peft==0.15.1
transformers==4.51.3 transformers==4.51.3
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.6.0 accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0
deepspeed>=0.15.4 deepspeed>=0.15.4
trl==0.17.0 trl==0.16.1
hf_xet==1.0.0 hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try: try:
torch_version = version("torch") torch_version = version("torch")
except PackageNotFoundError: except PackageNotFoundError:
torch_version = "2.6.0" # default to torch 2.6 torch_version = "2.5.1"
_install_requires.append(f"torch=={torch_version}") _install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version) version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,16 +64,10 @@ def parse_requirements(extras_require_map):
else: else:
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 7): if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0 _install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.4"] extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
if patch == 0: if patch == 0:
@@ -149,9 +143,6 @@ extras_require = {
"vllm": [ "vllm": [
"vllm==0.7.2", "vllm==0.7.2",
], ],
"llmcompressor": [
"llmcompressor==0.5.1",
],
} }
install_requires, dependency_links, extras_require_build = parse_requirements( install_requires, dependency_links, extras_require_build = parse_requirements(
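
This hunk gates dependency pins on the installed torch version: it parses `version("torch")`, extracts the major and minor numbers, and swaps the xformers and vllm pins accordingly. A standalone sketch of that gate, with the pins and fallback version copied from one side of the diff and best treated as illustrative:

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Determine which torch is (or will be) installed; fall back to a pinned default.
try:
    torch_version = version("torch")
except PackageNotFoundError:
    torch_version = "2.5.1"

match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if match is None:
    raise ValueError("Invalid version format")
major, minor = int(match.group(1)), int(match.group(2))

# Swap in version-specific pins, mirroring the branches in the hunk above.
install_requires = []
extras = {"vllm": ["vllm==0.7.2"]}
if (major, minor) >= (2, 6):
    install_requires.append("xformers==0.0.29.post2")
    extras["vllm"] = ["vllm==0.8.3"]
else:
    install_requires.append("xformers>=0.0.23.post1")
```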

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs: class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command.""" """Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: Optional[int] = field( tensor_parallel_size: int = field(
default=None, default=1,
metadata={"help": "Number of tensor parallel workers to use."}, metadata={"help": "Number of tensor parallel workers to use."},
) )
host: Optional[str] = field( host: str = field(
default=None, # nosec B104 default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."}, metadata={"help": "Host address to run the server on."},
) )
port: Optional[int] = field( port: int = field(
default=None, default=8000,
metadata={"help": "Port to run the server on."}, metadata={"help": "Port to run the server on."},
) )
gpu_memory_utilization: Optional[float] = field( gpu_memory_utilization: Optional[float] = field(

View File

@@ -129,17 +129,19 @@ def load_preference_datasets(
total_num_steps = None total_num_steps = None
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") if not cfg.rl == "grpo":
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples, check_dataset_labels(
tokenizer, train_samples,
num_examples=cli_args.debug_num_examples, tokenizer,
text_only=cli_args.debug_text_only, num_examples=cli_args.debug_num_examples,
rl_mode=True, text_only=cli_args.debug_text_only,
) rl_mode=True,
)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,

View File

@@ -932,6 +932,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt" kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator( return collator(
*collator_args, *collator_args,
@@ -1037,20 +1040,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes: if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None: if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
elif self.cfg.rl_beta is not None: if self.cfg.orpo_alpha:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None training_args_cls = None
blocklist_args_kwargs = [] blocklist_args_kwargs = []
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
@@ -1121,12 +1119,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs, **training_args_kwargs,
) )
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args return training_args
def build(self, total_num_steps): def build(self, total_num_steps):

View File

@@ -371,15 +371,13 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
) )
loss = super().compute_loss( return super().compute_loss(
model, model,
inputs, inputs,
return_outputs=return_outputs, return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
) )
return loss
@staticmethod @staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {} concatenated_batch = {}

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm: if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout: if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex: if trl.vllm_guided_decoding_regex:
@@ -135,9 +135,7 @@ class GRPOStrategy:
try: try:
# use importlib to dynamically load the reward function from the module # use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1] reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module( reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
".".join(reward_func_fqn.split(".")[:-1])
)
reward_func = getattr(reward_func_module, reward_func_module_name) reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2: if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError( raise ValueError(
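
Both sides of this hunk load a user-supplied reward function from its fully qualified dotted name with `importlib`, differing in how much of the parent module path they resolve. A self-contained sketch of the full-path variant; the module and function names are hypothetical, and the two-parameter check mirrors the diff:

```python
import importlib
import inspect


def load_reward_func(fqn: str):
    """Resolve a function such as 'my_rewards.length_penalty' from its dotted path.

    Minimal sketch: split off the final attribute name, import the parent module,
    then fetch the attribute.
    """
    module_path, func_name = fqn.rsplit(".", 1)
    module = importlib.import_module(module_path)
    func = getattr(module, func_name)
    # GRPO-style reward functions take at least (prompts, completions, ...).
    if len(inspect.signature(func).parameters) < 2:
        raise ValueError(f"{fqn} must accept at least two arguments")
    return func
```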

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,86 +1,16 @@
""" """Module for Axolotl trainer sequence parallelism mixin"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
import functools
import logging import logging
import torch
import torch.distributed as dist import torch.distributed as dist
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
return batch
class SequenceParallelMixin: class SequenceParallelMixin:
""" """
Mixin class for sequence parallelism support in trainers. Mixin class for sequence parallelism support in trainers.
@@ -157,157 +87,3 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler( return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True eval_dataset, shuffle=False, is_eval=True
) )
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result
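
The removed `apply_sequence_parallelism` helper slices every sequence-length tensor in the batch so that each rank in the sequence-parallel group keeps only its chunk, either contiguously or in a zigzag pairing that balances causal-attention work. A toy single-process sketch of that slicing, with no distributed setup, ring-attention hooks, or stripe mode:

```python
import torch


def shard_sequence(batch: dict[str, torch.Tensor], rank: int, world_size: int,
                   zigzag: bool = False) -> dict[str, torch.Tensor]:
    """Keep only this rank's share of the sequence dimension for each tensor.

    'zigzag' pairs chunk i with chunk (2 * world_size - i - 1) so early and late
    tokens are spread across ranks. Illustrative only; the real helper also
    updates ring-attention parameters from position_ids.
    """
    seq_len = batch["input_ids"].size(1)
    out = {}
    for key, value in batch.items():
        if not (isinstance(value, torch.Tensor) and value.dim() > 1 and value.size(1) == seq_len):
            out[key] = value
            continue
        if zigzag:
            chunks = value.chunk(2 * world_size, dim=1)
            out[key] = torch.cat(
                [chunks[rank], chunks[2 * world_size - rank - 1]], dim=1
            ).contiguous()
        else:
            out[key] = value.chunk(world_size, dim=1)[rank].contiguous()
    return out


# Example: a 2-rank group and an 8-token batch; rank 0 keeps tokens 0-3.
batch = {"input_ids": torch.arange(8).unsqueeze(0), "labels": torch.arange(8).unsqueeze(0)}
print(shard_sequence(batch, rank=0, world_size=2)["input_ids"])  # tensor([[0, 1, 2, 3]])
```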

View File

@@ -27,6 +27,8 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml ```yaml
plugins: plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
``` ```
## Supported Models ## Supported Models
@@ -45,8 +47,6 @@ plugins:
- qwen2 - qwen2
- cohere - cohere
- cohere2 - cohere2
- glm
- glm4
## Citation ## Citation

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy. Input args for Cut Cross Entropy.
""" """
cut_cross_entropy: Optional[bool] = True cut_cross_entropy: Optional[bool] = None
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@@ -1,57 +0,0 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -20,10 +20,6 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3, patch_gemma3,
patch_gemma3_text, patch_gemma3_text,
) )
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4, patch_llama4,
patch_llama4_text, patch_llama4_text,
@@ -49,8 +45,6 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2, "qwen2": patch_qwen2,
"cohere": patch_cohere, "cohere": patch_cohere,
"cohere2": patch_cohere2, "cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
} }

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2 - deepseek_v2
- gemma - gemma
- gemma2 - gemma2
- gemma3 - gemma3 (partial support, no support for FLCE yet)
- granite - granite
- jamba - jamba
- llama - llama

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
) )
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4": elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import ( from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4, apply_liger_kernel_to_llama4,
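
The added Gemma3 branch swaps the model's RMSNorm for the Liger kernel through a small adapter, since Gemma3 builds its norm with a positional `dim` argument while `LigerRMSNorm` expects `hidden_size`. The same pattern as a standalone snippet, assuming the import paths shown in the hunk:

```python
from functools import partial

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from transformers.models.gemma3 import modeling_gemma3


def _liger_rms_norm_wrapper(dim, **kwargs):
    """Convert Gemma3's positional 'dim' into LigerRMSNorm's 'hidden_size'."""
    return LigerRMSNorm(hidden_size=dim, **kwargs)


# Bind the Gemma-specific settings once; model code that instantiates
# Gemma3RMSNorm(dim, eps=...) now builds the Liger kernel instead.
modeling_gemma3.Gemma3RMSNorm = partial(
    _liger_rms_norm_wrapper,
    offset=1.0,
    casting_mode="gemma",
    init_fn="zeros",
    in_place=False,
)
```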

View File

@@ -1,108 +0,0 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -1,5 +0,0 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -1,40 +0,0 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -1,171 +0,0 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -1,40 +0,0 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -31,8 +31,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2", "starcoder2",
"deepseek_v2", "deepseek_v2",
"deepseek_v3", "deepseek_v3",
"glm",
"glm4",
] ]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch) super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
def get_lr(self) -> float: def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -4,30 +4,73 @@ module for base dataset transform strategies
import importlib import importlib
import logging import logging
import sys
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def import_from_path(module_name: str, file_path: str):
"""
Import a module from a file path.
Args:
module_name: Name of the module.
file_path: Path to the file.
Returns:
module: The imported module.
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Could not create module spec for: {file_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
loader = importlib.machinery.SourceFileLoader(module_name, file_path)
spec.loader = loader
loader.exec_module(module)
return module
def load(strategy, cfg, module_base=None, **kwargs): def load(strategy, cfg, module_base=None, **kwargs):
try: if len(strategy.split(".")) == 1:
if len(strategy.split(".")) == 1: strategy = strategy + ".default"
strategy = strategy + ".default" load_fn = strategy.split(".")[-1]
load_fn = strategy.split(".")[-1] func = None
if len(strategy.split(".")) > 1: if len(strategy.split(".")) > 1:
try: try:
importlib.import_module( mod = importlib.import_module(
strategy.split(".")[-2], strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]), ".".join(strategy.split(".")[:-2]),
) )
module_base = ".".join(strategy.split(".")[:-2]) func = getattr(mod, load_fn)
strategy = strategy.split(".")[-2] return func(cfg, **kwargs)
except ModuleNotFoundError: except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1]) pass
else:
strategy = "." + ".".join(strategy.split(".")[:-1]) try:
mod = importlib.import_module(
"." + ".".join(strategy.split(".")[:-1]), module_base
)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except ModuleNotFoundError:
pass
try:
file_path = "/".join(strategy.split(".")[:-1]) + ".py"
module_name = strategy.split(".")[-2]
mod = import_from_path(module_name, file_path)
func = getattr(mod, load_fn)
if func is not None:
return func(cfg, **kwargs)
except FileNotFoundError:
pass
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base) mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn) func = getattr(mod, load_fn)
return func(cfg, **kwargs) return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}") LOG.warning(f"unable to load strategy {strategy}")
return None return func
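
The new fallback in this hunk loads a strategy module directly from a `.py` file when `importlib.import_module` cannot resolve the dotted name. A minimal sketch of that recipe; the example path in the docstring is hypothetical:

```python
import importlib.util
import sys


def import_from_path(module_name: str, file_path: str):
    """Load a module straight from a file, e.g. './my_strategies/custom.py'.

    Sketch of the standard spec-based import: build a spec from the path,
    create the module, register it in sys.modules, then execute it.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not create module spec for: {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before exec so lookups resolve
    spec.loader.exec_module(module)
    return module
```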

View File

@@ -6,7 +6,6 @@ import os
import signal import signal
import sys import sys
import weakref import weakref
from contextlib import nullcontext
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
@@ -26,9 +25,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
@@ -189,28 +185,16 @@ def execute_training(
        trainer: The configured trainer object.
        resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
    """
    # Define the context managers to use
    flash_context = (
        torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True,
        )
        if cfg.flash_optimum
        else nullcontext()
    )
    sequence_parallel_context = (
        SequenceParallelContextManager(
            model=trainer.model,
            sequence_parallel_degree=cfg.sequence_parallel_degree,
            ring_attn_func=cfg.ring_attn_func,
        )
        if cfg.sequence_parallel_degree > 1
        else nullcontext()
    )

    LOG.info("Starting trainer...")
    with flash_context, sequence_parallel_context:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    LOG.info("Starting trainer...")
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
            enable_flash=True,
            enable_math=True,
            enable_mem_efficient=True,
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
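The first variant builds every optional feature as a context manager up front, substituting contextlib.nullcontext() when the feature is off, so the training call stays under a single with statement. A small self-contained sketch of that pattern (the feature names are made up for illustration):

from contextlib import contextmanager, nullcontext


@contextmanager
def feature_patch(name):
    print(f"enable {name}")
    try:
        yield
    finally:
        print(f"disable {name}")


def run_training(flash: bool, sp_degree: int):
    flash_ctx = feature_patch("flash") if flash else nullcontext()
    sp_ctx = feature_patch("sequence-parallel") if sp_degree > 1 else nullcontext()
    with flash_ctx, sp_ctx:
        print("training...")


run_training(flash=True, sp_degree=1)  # only the flash context is active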
@@ -295,23 +279,8 @@ def save_trained_model(
trainer.model.save_pretrained( trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization cfg.output_dir, safe_serialization=safe_serialization
) )
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
def create_model_card(cfg: DictDefault, trainer: Trainer): def create_model_card(cfg: DictDefault, trainer: Trainer):
""" """

View File

@@ -1,12 +1,20 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences""" """
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import numpy as np import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass @dataclass
class DataCollatorForSeq2Seq: class DataCollatorForSeq2Seq:
@@ -41,6 +49,8 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`): return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf". The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
""" """
tokenizer: PreTrainedTokenizerBase tokenizer: PreTrainedTokenizerBase
@@ -51,6 +61,17 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100 label_pad_token_id: int = -100
position_pad_token_id: int = 0 position_pad_token_id: int = 0
return_tensors: str = "pt" return_tensors: str = "pt"
sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None): def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys() has_attn_mask = "attention_mask" in features[0].keys()
@@ -120,8 +141,62 @@ class DataCollatorForSeq2Seq:
) )
features["decoder_input_ids"] = decoder_input_ids features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].size(1)
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
for key in batch:
if batch[key].size(1) == total_seq_len:
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
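For the BATCH_ZIGZAG branch above, each rank keeps chunk i and the mirrored chunk (2 * world_size - 1 - i). A toy illustration of that selection (not the collator itself) with an 8-token sequence and two ranks:

import torch

seq = torch.arange(8).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
world_size = 2

for rank in range(world_size):
    chunks = seq.chunk(2 * world_size, dim=1)
    shard = torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=1)
    print(rank, shard.tolist())
# rank 0 -> [[0, 1, 6, 7]], rank 1 -> [[2, 3, 4, 5]]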
@dataclass @dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

View File

@@ -126,6 +126,9 @@ def normalize_config(cfg):
        with open(ds_config_path, encoding="utf-8") as f:
            cfg.deepspeed = json.load(f)

    if cfg.sequence_parallel_degree is None:
        cfg.sequence_parallel_degree = 1

    if cfg.saves_per_epoch:
        save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
        if save_steps < 1.0:  # prevent saves on every step
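A quick worked example of the save_steps fraction computed above (the numbers are illustrative, not from the diff):

saves_per_epoch, num_epochs = 4, 1
save_steps = 1.0 / (saves_per_epoch * num_epochs)
assert save_steps == 0.25  # i.e. a checkpoint every quarter of the run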

View File

@@ -3,7 +3,6 @@
import functools import functools
import logging import logging
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -118,26 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
        # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
        # other ranks, we just need to present a fake dataset
        if (
            cfg.accelerator_config
            and cfg.accelerator_config.dispatch_batches
            and not is_local_main_process()
        ):
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
                f.write("text\n")
                f.write("lorem ipsum dolor sit amet\n")
                # rewind the file pointer to the beginning so we can read it again
                f.seek(0)
                iter_ds = load_dataset(
                    "csv", data_files=f.name, split="train", streaming=True
                )
        else:
            iter_ds = load_dataset(
                path, streaming=True, split=split, name=name, data_files=data_files
            )

        iter_ds = load_dataset(
            path, streaming=True, split=split, name=name, data_files=data_files
        )

        if skip:
            LOG.info(f"Skipping {skip} samples from the dataset")
            iter_ds = iter_ds.skip(skip)

View File

@@ -1,7 +1,5 @@
"""custom checkpointing utils""" """custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import ( from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer, Unsloth_Offloaded_Gradient_Checkpointer,
) )
@@ -11,10 +9,6 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        (
            decoder_layer.func.__self__
            if isinstance(decoder_layer, partial)
            else decoder_layer.__self__
        ),
        *args,
    )
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
        decoder_layer.__self__,
        *args,
    )

View File

@@ -139,22 +139,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config") hasattr(model_config, "quantization_config")
and model_config.quantization_config and model_config.quantization_config
) )
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
    quant_config_method_is_gptq = (
        quant_config_exists
        and "quant_method" in model_config.quantization_config

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.num_warmup_steps = num_warmup_steps
        self.last_step = max(last_step - 1, 0)
        self.last_step = last_step - 1

        # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
        for group in optimizer.param_groups:
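Why the clamp in the first variant matters, in two lines (simple arithmetic, not repo code): with the default last_step=0, the unclamped form initializes the scheduler step to -1.

last_step = 0
assert max(last_step - 1, 0) == 0  # clamped initialization
assert last_step - 1 == -1         # unclamped initialization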

View File

@@ -660,7 +660,6 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0 data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy")) and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets") and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
): ):
raise ValueError( raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0" "eval_steps and eval_strategy are not supported with val_set_size == 0"
@@ -1149,17 +1148,22 @@ class AxolotlInputConfig(
        return data

    @model_validator(mode="after")
    def check_sequence_parallel_degree(self):
        if not self.sequence_parallel_degree:
            self.sequence_parallel_degree = 1
        elif self.sequence_parallel_degree > 1:
            if not self.flash_attention:
                raise ValueError(
                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                )
            if self.sample_packing and self.micro_batch_size > 1:
                raise ValueError(
                    "micro_batch_size must be set to 1 when sample_packing is enabled"
                    "due to a `ring-flash-attn` requirement"

    @field_validator("sequence_parallel_degree", mode="after")
    @classmethod
    def check_sequence_parallel_degree(cls, value, info):
        if not value:
            value = 1
        if value > 1:
            if not info.data.get("flash_attention"):
                raise ValueError(
                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                )
            if (
                info.data.get("sample_packing")
                and not info.data["micro_batch_size"] == 1
            ):
                raise ValueError(
                    "micro_batch_size must be set to 1 when sample_packing is enabled"
                    "due to a `ring-flash-attn` requirement"
@@ -1179,40 +1183,42 @@ class AxolotlInputConfig(
            # according to the proportion of non-padding tokens per rank.
            LOG.warning(
                "Sequence parallelism (SP) is enabled with "
                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
                "Please note that logged losses may differ slightly to the non-SP "
                "losses due to transformers Trainer implementation details. "
                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                "for more details."
            )

        return self

    @model_validator(mode="after")
    def validate_ring_attn_func(self):
        if getattr(self, "sequence_parallel_degree", 1) == 1:
            return self

        from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

        if self.ring_attn_func is not None:
            valid_funcs = list(RingAttnFunc)
            if self.ring_attn_func in valid_funcs:
                self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
            else:
                raise ValueError(
                    f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
                )
        else:
            # Default ring attention function selection
            sample_packing = getattr(self, "sample_packing", False)
            self.ring_attn_func = (
                RingAttnFunc.VARLEN_LLAMA3
                if sample_packing
                else RingAttnFunc.BATCH_RING
            )

        return self

            # according to the proportion of non-padding tokens per rank.
            LOG.warning(
                "Sequence parallelism (SP) is enabled with "
                f"sequence_parallel_degree={value}. Please note that logged losses may "
                "differ slightly to the non-SP losses due to transformers Trainer "
                "implementation details. Please see "
                "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
                "for more details."
            )

        return value

    @field_validator("ring_attn_func", mode="after")
    @classmethod
    def check_ring_attn_func(cls, value, info):
        if not info.data.get("sequence_parallel_degree", 1) > 1:
            return value

        from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

        if value is not None:
            # Set the ring attention function if passed in config
            valid_funcs = list(RingAttnFunc)
            if value in valid_funcs:
                value = RingAttnFunc(value)
            else:
                raise ValueError(
                    f"ring_attn_func: {value} must be one of {valid_funcs}"
                )
        else:
            # Default ring attention function selection
            sample_packing = info.data.get("sample_packing")
            value = (
                RingAttnFunc.VARLEN_LLAMA3
                if sample_packing
                else RingAttnFunc.BATCH_RING
            )

        return value

    @model_validator(mode="before")
    @classmethod

View File

@@ -36,11 +36,3 @@ class VllmConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"}, json_schema_extra={"description": "Enable prefix caching for VLLM"},
) )
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing: elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
if cfg.eval_sample_packing: if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,
@@ -528,13 +528,6 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None): def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage: if stage:

View File

@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture(): def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model") snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture(): def download_llama_1b_model_fixture():
# download the tokenizer only # download the tokenizer only
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture @pytest.fixture
@enable_hf_offline @enable_hf_offline
def tokenizer_huggyllama( def tokenizer_huggyllama(

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists from ..utils import check_model_output_exists
@@ -56,7 +56,6 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir): def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg) cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -102,7 +101,6 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -131,7 +129,6 @@ class TestCutCrossEntropyIntegration:
attention_type: True, attention_type: True,
} }
) )
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1 from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,7 +54,6 @@ class LigerIntegrationTestCase:
} }
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
@@ -101,7 +100,6 @@ class LigerIntegrationTestCase:
} }
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg) prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()

View File

@@ -1,106 +0,0 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import (
check_model_output_exists,
require_llmcompressor,
require_torch_2_4_1,
)
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
@require_llmcompressor
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
if save_compressed:
assert (Path(temp_dir) / "recipe.yaml").exists()
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -1,2 +0,0 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,9 +49,8 @@ class TestPackedFlex:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -4,14 +4,11 @@ GRPO test suite
import os import os
import random import random
import shutil
import subprocess # nosec B404 import subprocess # nosec B404
import sys import sys
import tempfile
import time import time
from pathlib import Path from pathlib import Path
import psutil
import pytest import pytest
import requests import requests
import yaml import yaml
@@ -24,8 +21,8 @@ from tests.e2e.utils import require_vllm
def start_vllm(
    model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
) -> subprocess.Popen:
def start_vllm(
    model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
    """
    helper function to start the VLLM server in the background, mostly for testing purposes
    """
@@ -49,41 +46,10 @@ def start_vllm(
# print out the command to be executed # print out the command to be executed
print(" ".join(cmd)) print(" ".join(cmd))
vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
temp_file.write(
"""{
"formatters": {
"json": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
}
},
"handlers": {
"file": {
"class": "logging.FileHandler",
"formatter": "json",
"level": "DEBUG",
"filename": "/tmp/vllm.log",
"mode": "a"
}
},
"loggers": {
"vllm": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": false
}
},
"version": 1
}"""
)
cmd_env = env.copy()
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
# start `trl vllm-serve` command in the background and capture the process id # start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with process = subprocess.Popen( # pylint: disable=consider-using-with
cmd, cmd,
env=cmd_env,
env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE, stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE, stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603 ) # nosec B603
@@ -92,51 +58,32 @@ def start_vllm(
print(f"VLLM server process started (PID: {process.pid})") print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds # wait until the http server is ready, even if it 404s, but timeout after 60 seconds
period_seconds = 5
started = False started = False
if wait and host and port: if wait and host and port:
for i in range(0, int(wait), period_seconds):
for _ in range(int(wait)):
try: try:
response = requests.get(f"http://{host}:{port}", timeout=1) response = requests.get(f"http://{host}:{port}", timeout=1)
print(f"{i}: VLLM server (status: {response.status_code})")
if int(response.status_code) in [200, 404]: if int(response.status_code) in [200, 404]:
started = True started = True
break break
except requests.exceptions.RequestException as exc:
    print(f"{i}: VLLM server failed to start: {str(exc)}")
except requests.exceptions.RequestException:
    pass
# also check if the process.pid is still running # also check if the process.pid is still running
if not process.poll() is None: if not process.poll() is None:
break break
time.sleep(period_seconds)
time.sleep(1)
if wait and not started: if wait and not started:
print( print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs." f"VLLM server process did not start within {wait} seconds. Please check your server logs."
) )
recursive_kill(process)
process.kill()
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.") raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process
return process
# return the process id
return process.pid
def recursive_kill(process: subprocess.Popen):
"""
Recursively kill a process and its children
"""
process = psutil.Process(process.pid)
for child in psutil.Process(process.pid).children(recursive=True):
child.terminate()
child.kill()
os.kill(child.pid, 9)
process.terminate()
process.kill()
os.kill(process.pid, 9)
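A hedged usage sketch of the two helpers above, following the Popen-returning variant of start_vllm; the wait time is illustrative and the training call is elided:

import os


def run_with_vllm(cfg):
    env = os.environ.copy()
    process = start_vllm(
        cfg.base_model,
        env=env,
        wait=300,
        quiet=True,
    )
    try:
        ...  # run the GRPO training that talks to the vLLM server
    finally:
        recursive_kill(process)  # tear down the server and any child workers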
class TestGRPO: class TestGRPO:
@@ -227,17 +174,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "NVL", "NCCL_P2P_LEVEL": "LOC",
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1", "VLLM_USE_V1": "0",
# "VLLM_USE_V1": "0",
} }
vllm_process = start_vllm(
vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=300,
wait=120,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -256,14 +202,10 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={
    "NCCL_P2P_LEVEL": "NVL",
    "NCCL_DEBUG": "INFO",
    **current_env,
},
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
) )
finally: finally:
recursive_kill(vllm_process)
os.kill(vllm_process_id, 9)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_gpus", "num_gpus",
@@ -320,17 +262,16 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy() current_env = os.environ.copy()
env = { env = {
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env, **current_env,
"CUDA_VISIBLE_DEVICES": "1", "CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1", "VLLM_USE_V1": "0",
# "VLLM_USE_V1": "0",
} }
vllm_process = start_vllm(
vllm_process_id = start_vllm(
cfg.base_model, cfg.base_model,
env=env, env=env,
quiet=True, quiet=True,
wait=300,
wait=120,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len, max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching, enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -349,11 +290,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port", "--main-process-port",
f"{get_torch_dist_unique_port()}", f"{get_torch_dist_unique_port()}",
], ],
env={
    "NCCL_P2P_LEVEL": "NVL",
    "NCCL_DEBUG": "INFO",
    **current_env,
},
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
) )
finally: finally:
recursive_kill(vllm_process)
os.kill(vllm_process_id, 9)

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
from ..utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -60,7 +60,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True, "fp16": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True, "fp16": True,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -1,77 +0,0 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,7 +103,6 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -67,7 +67,6 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True cfg.bf16 = True
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -65,7 +65,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,7 +105,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -70,7 +70,6 @@ class TestLoraLlama(unittest.TestCase):
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -121,7 +120,6 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,7 +104,6 @@ class TestMistral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -60,7 +60,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,7 +98,6 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto", "bf16": "auto",
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers import transformers
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
@@ -47,7 +47,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10, "eval_steps": 10,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False) load_model(cfg, tokenizer, inference=False)
@@ -80,7 +79,6 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10, "eval_steps": 10,
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False) load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
@@ -63,7 +63,6 @@ class TestPhiMultipack(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -83,7 +82,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"load_in_4bit": True, "load_in_8bit": False,
"adapter": "qlora", "adapter": "qlora",
"lora_r": 64, "lora_r": 64,
"lora_alpha": 32, "lora_alpha": 32,
@@ -115,7 +114,6 @@ class TestPhiMultipack(unittest.TestCase):
} }
) )
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir from ..utils import check_model_output_exists, most_recent_subdir
@@ -46,9 +46,8 @@ class TestResumeLlama:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 2, "num_epochs": 2,
@@ -68,7 +67,6 @@ class TestResumeLlama:
cfg.bf16 = True cfg.bf16 = True
else: else:
cfg.fp16 = True cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,19 +2,14 @@
# pylint: disable=redefined-outer-name,unused-argument # pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from accelerate.state import PartialState from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group, get_ring_attn_group,
register_ring_attn,
set_ring_attn_group, set_ring_attn_group,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -52,27 +47,6 @@ def fixture_cfg():
return cfg return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return batch
class TestRingAttention: class TestRingAttention:
"""Tests for the ring attention functionality.""" """Tests for the ring attention functionality."""
@@ -99,6 +73,11 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state self, mock_world_size, mock_rank, mock_new_group, partial_state
): ):
"""Test that ring attention groups are created correctly.""" """Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks # Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3 mock_rank.return_value = 3 # GPU #3
@@ -122,303 +101,88 @@ class TestRingAttention:
set_ring_attn_group(None) set_ring_attn_group(None)
class TestConfigValidation: # Mock a simplified DataCollator test
"""Tests for validating sequence parallelism configurations.""" @patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
@pytest.fixture(autouse=True) # Create a sample batch
def setup_mocks(self, monkeypatch): batch = {
"""Set up mocks for all tests in this class.""" "input_ids": torch.tensor(
# Mock the ring_flash_attn module [
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
@pytest.fixture # Simplified slicing logic from SequenceParallelDataCollator
def base_cfg(self): def slice_batch(batch, rank, world_size):
"""Create a base configuration for testing.""" result = {}
return DictDefault( for key in batch:
{ seq_len = batch[key].shape[1]
"base_model": "HuggingFaceTB/SmolLM2-135M", slice_size = seq_len // world_size
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], start_idx = rank * slice_size
"micro_batch_size": 1, end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
"gradient_accumulation_steps": 1, result[key] = batch[key][:, start_idx:end_idx]
"learning_rate": 1e-3, return result
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
@pytest.mark.parametrize( # Slice the batch
"config_updates, expected_values, should_pass, error_msg", result = slice_batch(
[ batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
) )
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config # Check slicing
cfg = base_cfg assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
cfg.update(config_updates) expected_input_ids = torch.tensor(
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
[ [
(None, True, RingAttnFunc.VARLEN_LLAMA3), [104, 105, 106], # Second slice of first sequence
(None, False, RingAttnFunc.BATCH_RING), [204, 205, 206], # Second slice of second sequence
], ]
ids=["default_with_sample_packing", "default_without_sample_packing"],
) )
def test_ring_attn_func_validation( assert torch.all(result["input_ids"] == expected_input_ids)
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
class TestApplySequenceParallelism: @patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
"""Tests for the apply_sequence_parallelism function.""" def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
@pytest.fixture(autouse=True) # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
def mock_distributed(self, monkeypatch): cfg = cfg | {
"""Mock torch.distributed functions for testing.""" "sequence_parallel_degree": 2,
# Mock is_initialized to return True "flash_attention": True,
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) }
# Mock get_rank to return 0 by default # Should validate without errors
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0) config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
# Mock the process group def test_config_validation_with_invalid_inputs(cfg):
monkeypatch.setattr( """Test that invalid sequence parallelism configurations fail validation."""
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", from axolotl.utils.schemas.config import AxolotlInputConfig
MagicMock,
)
# Mock update_ring_attn_params # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
monkeypatch.setattr( cfg = cfg | {
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", "sequence_parallel_degree": 2,
lambda **kwargs: None, "flash_attention": False,
) }
def test_world_size_one(self, sequence_parallel_batch): # Should raise ValidationError
"""Test that function returns original batch when world size is 1.""" with pytest.raises(ValueError) as excinfo:
result = apply_sequence_parallelism( AxolotlInputConfig(**cfg)
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Should return the original batch unchanged # Verify error message
assert result == sequence_parallel_batch assert "flash_attention: true must be set" in str(excinfo.value)
def test_batch_ring_rank0(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
    def test_batch_zigzag(self, sequence_parallel_batch):
        """Test BATCH_ZIGZAG sharding pattern."""
        batch = sequence_parallel_batch
        original_input_ids = batch["input_ids"].clone()
        seq_len = batch["input_ids"].size(1)

        # Test rank 0
        result_rank0 = apply_sequence_parallelism(
            batch={k: v.clone() for k, v in batch.items()},
            local_rank=0,
            local_world_size=2,
            ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
        )

        # Test rank 1
        result_rank1 = apply_sequence_parallelism(
            batch={k: v.clone() for k, v in batch.items()},
            local_rank=1,
            local_world_size=2,
            ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
        )

        # Checks for both ranks
        assert result_rank0["input_ids"].shape[1] == seq_len // 2
        assert result_rank1["input_ids"].shape[1] == seq_len // 2

        # For a 2-rank system with 8 tokens, check the specific zigzag pattern:
        # Rank 0 should get chunks [0, 1] and [6, 7]
        # Rank 1 should get chunks [2, 3] and [4, 5]
        if seq_len == 8:
            # Create expected tensors for comparison
            rank0_expected = torch.cat(
                [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
            )
            rank1_expected = torch.cat(
                [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
            )

            assert torch.equal(result_rank0["input_ids"], rank0_expected)
            assert torch.equal(result_rank1["input_ids"], rank1_expected)
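The expected chunks above follow the usual zigzag layout for ring attention: the sequence is split into 2 * world_size chunks, and rank r keeps chunk r together with its mirror chunk from the far end, which balances attention work across ranks. A small self-contained sketch of that selection (illustrative only, independent of the axolotl helpers under test):

import torch


def zigzag_shard(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Split the sequence dimension into 2 * world_size equal chunks.
    chunks = x.chunk(2 * world_size, dim=1)
    # Rank r takes chunk r plus the mirrored chunk 2 * world_size - 1 - r.
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)


# With 8 tokens and 2 ranks: rank 0 keeps tokens [0, 1] and [6, 7],
# rank 1 keeps tokens [2, 3] and [4, 5], matching the assertions above.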
    def test_partial_application(self, sequence_parallel_batch):
        """Test that we can create a partially applied version of the function."""
        batch = sequence_parallel_batch
        original_input_ids = batch["input_ids"].clone()

        # Create a partially applied function
        rank0_ring_parallel = functools.partial(
            apply_sequence_parallelism,
            local_rank=0,
            local_world_size=2,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # Use the partially applied function
        result = rank0_ring_parallel(batch=batch)

        # Verify it works as expected
        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
        assert torch.equal(
            result["input_ids"],
            original_input_ids[:, : original_input_ids.shape[1] // 2],
        )
    def test_missing_position_ids(self, sequence_parallel_batch):
        """Test handling of batch without position_ids."""
        # Create a batch without position_ids
        batch = {
            k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
        }
        original_input_ids = batch["input_ids"].clone()

        # This should run without error even though position_ids is missing
        result = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # Verification should pass
        assert "position_ids" not in result
        assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2

View File

@@ -10,7 +10,7 @@ import pytest
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

 from ..utils import check_model_output_exists, check_tensorboard
@@ -72,7 +72,6 @@ class TestUnslothQLoRA:
             }
         )

-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -123,7 +122,6 @@ class TestUnslothQLoRA:
             }
         )

-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -179,7 +177,6 @@ class TestUnslothQLoRA:
             }
         )

-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
                 },
                 "datasets": [
                     {
-                        "path": "tatsu-lab/alpaca",
+                        "path": "vicgalle/alpaca-gpt4",
                         "type": "alpaca",
-                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 1,

View File

@@ -0,0 +1,85 @@
"""
E2E tests for preprocessing
"""

import logging
import os
import unittest

import transformers

from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestCustomRewardFunctionLoading(unittest.TestCase):
    """
    Test case for custom GRPO reward function loading on a single GPU
    """

    def _utils_write_rewards(self):
        # write the custom reward and transform functions to a local rewards.py file
        with open("rewards.py", "w", encoding="utf-8") as fout:
            fout.write(
                """import random


def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]


def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }

    return transform_fn, {"remove_columns": ["question"]}
"""
            )

    @with_temp_dir
    def test_custom_rewards_fn_preprocess(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "strict": False,
                "rl": "grpo",
                "trl": {
                    "beta": 0.001,
                    "max_completion_length": 256,
                    "use_vllm": True,
                    "num_generations": 4,
                    "reward_funcs": [
                        "rewards.rand_reward_func"
                    ],  # format: '{file_name}.{fn_name}'
                    "reward_weights": [1.0],
                },
                "datasets": [
                    {
                        "path": "openai/gsm8k",
                        "name": "main",
                        "type": "rewards.oai_gsm8k_transform",
                    },
                ],
                "dataset_prepared_path": temp_dir,
                "gradient_accumulation_steps": 1,
                "micro_batch_size": 1,
                "learning_rate": 0.000005,
            }
        )

        self._utils_write_rewards()

        cfg = validate_config(cfg)
        normalize_config(cfg)

        parser = transformers.HfArgumentParser(PreprocessCliArgs)
        cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

        load_preference_datasets(cfg=cfg, cli_args=cli_args)
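The reward_funcs entries above use the '{file_name}.{fn_name}' convention, so the loader has to resolve a dotted name that may refer to a plain .py file sitting next to the config rather than an installed package. As a rough illustration of that resolution order (a sketch only, not the actual axolotl loader; the helper name is made up):

import importlib
import importlib.util
from pathlib import Path


def load_user_function(dotted_name: str):
    """Resolve 'module_or_file.fn_name' to a callable."""
    module_name, fn_name = dotted_name.rsplit(".", 1)
    try:
        # Prefer a normal import if the module is importable from sys.path.
        module = importlib.import_module(module_name)
    except ModuleNotFoundError:
        # Fall back to loading a local file such as ./rewards.py.
        file_path = Path(f"{module_name}.py").resolve()
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    return getattr(module, fn_name)


# e.g. load_user_function("rewards.rand_reward_func") would pick up ./rewards.py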

View File

@@ -102,7 +102,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
                 "use_tensorboard": True,
             }
         )
-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,7 +109,6 @@ class TestLlamaVision(unittest.TestCase):
                 "bf16": True,
             }
         )
-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -40,9 +40,8 @@ class TestPackedLlama(unittest.TestCase):
                 },
                 "datasets": [
                     {
-                        "path": "tatsu-lab/alpaca",
+                        "path": "vicgalle/alpaca-gpt4",
                         "type": "alpaca",
-                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 1,

View File

@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
                 "tokenizer_type": "AutoTokenizer",
                 "sequence_len": 2048,
                 "sample_packing": False,
-                "load_in_4bit": True,
+                "load_in_8bit": False,
                 "adapter": "qlora",
                 "lora_r": 64,
                 "lora_alpha": 32,
@@ -111,7 +111,6 @@ class TestPhi(unittest.TestCase):
                 "bf16": "auto",
             }
         )
-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

 from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -57,7 +57,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
                 "seed": 42,
             }
         )
-        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,24 +109,6 @@ def require_vllm(test_case):
     )(test_case)


-def require_llmcompressor(test_case):
-    """
-    Decorator marking a test that requires a llmcompressor to be installed
-    """
-
-    def is_llmcompressor_installed():
-        try:
-            import llmcompressor  # pylint: disable=unused-import  # noqa: F401
-
-            return True
-        except ImportError:
-            return False
-
-    return unittest.skipUnless(
-        is_llmcompressor_installed(), "test requires a llmcompressor to be installed"
-    )(test_case)
-
-
 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)

View File

@@ -11,7 +11,7 @@ from unittest.mock import patch
 import pytest
 from datasets import Dataset

-from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.config import normalize_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -319,7 +319,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
                 "num_epochs": 1,
             }
         )
-        self.cfg_1 = validate_config(self.cfg_1)
         normalize_config(self.cfg_1)

     @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")