Compare commits

..

30 Commits

Author SHA1 Message Date
Wing Lian
b708a1cc45 validate config to set defaults 2025-04-26 13:11:25 -04:00
Rahul Tuli
daa9a58f83 Add: line about further optimizations using llmcompressor
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-24 14:06:25 -04:00
Rahul Tuli
ae7069e15b Merge branch 'main' into llmcompressor-sft 2025-04-24 12:37:14 -05:00
Rahul Tuli
20d48cd617 Address Review Comments:
* deleted redundant docs/llm_compressor.qmd
* incorporated feedback in integration README.md
* added llmcompressor integration to docs/custom_integrations.qmd

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
2025-04-24 13:36:09 -04:00
Wing Lian
1447beb132 make sure to validate the config before normalizing so defaults get set (#2554)
* make sure to validate the config before normalizing so defaults get set

* validation not needed for particular test

* remove duplicate validations

* set qlora correctly
2025-04-24 13:01:43 -04:00
Rahul Tuli
e766a730ba Add: .qmd file 2025-04-24 12:45:57 -04:00
Rahul Tuli
7dc797860e Tests, Style, Updates 2025-04-24 12:45:57 -04:00
Rahul Tuli
ff4904c8c4 Rebase and updates! 2025-04-24 12:45:57 -04:00
Rahul Tuli
45b7293793 Add: llm_compressor integration documentation 2025-04-24 12:45:57 -04:00
Rahul Tuli
279c7178bc Move: LLMCompressorPlugin into it's own submodule 2025-04-24 12:45:57 -04:00
Rahul Tuli
e73c3709f9 Update model config 2025-04-24 12:45:57 -04:00
Rahul Tuli
33562189f8 Use: absolute import 2025-04-24 12:45:57 -04:00
Rahul Tuli
c057a2268f Rename: sft.yaml to sparse-finetuning.yaml 2025-04-24 12:45:57 -04:00
Rahul Tuli
9d7a3809b5 Add: llcompressor installable 2025-04-24 12:45:57 -04:00
Rahul Tuli
b7b24d6a64 Address review comments from @markurtz 2025-04-24 12:45:57 -04:00
Rahul Tuli
8b82b8f7a1 Apply suggestions from @markurtz
Co-authored-by: Mark Kurtz <mark.j.kurtz@gmail.com>
2025-04-24 12:45:57 -04:00
Rahul Tuli
81da58c0a1 Update llmcompressor version to latest 2025-04-24 12:45:57 -04:00
Rahul Tuli
2cd5a234a7 Revert: TODO's 2025-04-24 12:45:57 -04:00
Rahul Tuli
8c1af0747d Use: warning over warn 2025-04-24 12:45:57 -04:00
Rahul Tuli
a06b360d99 pre commit hooks 2025-04-24 12:45:57 -04:00
Rahul Tuli
0f6456a14f Add:llmcompressor instalable 2025-04-24 12:45:57 -04:00
Rahul Tuli
47a333ce49 Update: review comments! 2025-04-24 12:45:57 -04:00
Rahul Tuli
f9d6776c28 Add: SFTPlugin with llmcompressor 2025-04-24 12:45:57 -04:00
Dan Saunders
66f41ec6f1 disable codecov pr annotations (#2556) 2025-04-24 08:51:51 -04:00
NanoCode012
85053f4bd4 Fix(doc): add delinearize instruction (#2545)
* fix: mention to install pytorch before axolotl

* feat(doc): include instruction to delinearize

* fix: update instruction for delinearize with adapter
2025-04-24 01:03:43 -04:00
Wing Lian
a4d5112ae1 builds for torch 2.7.0 (#2552)
* builds for torch==2.7.0

* use xformers==0.0.29.post3

* no vllm support with torch 2.7

* update default, fix conditional

* no xformers for 270

* no vllm on 2.7.0 for multigpu test too

* remove deprecated verbose arg from scheduler

* 2.7.0 tests on cpu
2025-04-24 00:39:31 -04:00
Wing Lian
0d691cc2a7 add base docker image with pytorch 2.7.0 and variant for cuda 12.8 (#2551)
* add base docker image with pytorch 2.7.0 and variant for cuda 12.8

* my bash is terrible
2025-04-23 14:59:03 -04:00
Dan Saunders
c4053481ff Codecov fixes / improvements (#2549)
* adding codecov reporting

* random change

* codecov fixes

* adding missing dependency

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-23 10:33:30 -04:00
NanoCode012
a6d28d19b1 feat: add glm and glm4 multipack and cce (#2546)
* feat: add glm and glm4 multipack

* feat: add glm4 example

* feat: add cce for glm
2025-04-23 10:27:51 -04:00
Wing Lian
32e335dd51 fix missing host/port for vllm (#2543)
* fix missing host/port for vllm

* set tensor parallel size so it doesn't always default to cli override
2025-04-22 10:16:48 -04:00
66 changed files with 965 additions and 534 deletions

View File

@@ -46,6 +46,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -31,6 +31,11 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -93,6 +98,11 @@ jobs:
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -138,7 +148,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -45,6 +45,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -67,6 +74,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -147,6 +147,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -109,6 +109,7 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
@@ -241,6 +242,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -268,6 +270,12 @@ jobs:
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -288,6 +296,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -9,8 +9,7 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl \
--cov-report=xml:coverage.xml
--cov=axolotl
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
@@ -51,11 +50,6 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:coverage.xml
--cov-report=xml:e2e-coverage.xml
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}

View File

@@ -28,6 +28,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -29,6 +29,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -6,13 +6,13 @@ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \
--cov-report=xml:multigpu-coverage.xml
--cov=axolotl
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
--cov-append
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
@@ -20,8 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -49,3 +49,6 @@ comment:
require_changes: no
require_base: no
require_head: yes
github_checks:
annotations: false

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,6 +199,17 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -55,46 +55,20 @@ overrides_of_model_config:
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config paramters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model
gptq:
gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit:
load_in_8bit: true
# Use bitsandbytes 4 bit
load_in_4bit:

View File

@@ -49,7 +49,8 @@ sections = [
("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum")
("Spectrum", "spectrum"),
("LLMCompressor", "llm_compressor")
]
for section_name, folder_name in sections:

View File

@@ -19,6 +19,12 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}

View File

@@ -0,0 +1,62 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,77 @@
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true

View File

@@ -26,3 +26,11 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k contenxt length @
### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

View File

@@ -1,4 +1,5 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry

View File

@@ -19,10 +19,10 @@ datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2
hf_transfer
hqq==0.2.5
sentencepiece
gradio==5.23.3

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,9 +64,15 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 6):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
@@ -143,6 +149,9 @@ extras_require = {
"vllm": [
"vllm==0.7.2",
],
"llmcompressor": [
"llmcompressor==0.5.1",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
tensor_parallel_size: Optional[int] = field(
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
port: Optional[int] = field(
default=None,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:

View File

@@ -47,6 +47,8 @@ cut_cross_entropy: true
- qwen2
- cohere
- cohere2
- glm
- glm4
## Citation

View File

@@ -0,0 +1,57 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -45,6 +49,8 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}

View File

@@ -0,0 +1,108 @@
# LLMCompressor Integration
Fine-tune sparsified models in Axolotl using Neural Magic's [LLMCompressor](https://github.com/vllm-project/llm-compressor).
This integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressor's model compression capabilities with Axolotl's distributed training pipelines, users can efficiently fine-tune sparse models at scale.
It uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.
---
## Requirements
- Axolotl with `llmcompressor` extras:
```bash
pip install "axolotl[llmcompressor]"
```
- Requires `llmcompressor >= 0.5.1`
This will install all necessary dependencies to fine-tune sparsified models using the integration.
---
## Usage
To enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:
```yaml
plugins:
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
llmcompressor:
recipe:
finetuning_stage:
finetuning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
save_compressed: true
# ... (other training arguments)
```
This plugin **does not apply pruning or sparsification itself** — it is intended for **fine-tuning models that have already been sparsified**.
Pre-sparsified checkpoints can be:
- Generated using [LLMCompressor](https://github.com/vllm-project/llm-compressor)
- Downloaded from [Neural Magic's Hugging Face page](https://huggingface.co/neuralmagic)
- Any custom LLM with compatible sparsity patterns that you've created yourself
To learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:
[https://github.com/vllm-project/llm-compressor/blob/main/README.md](https://github.com/vllm-project/llm-compressor/blob/main/README.md)
### Storage Optimization with save_compressed
Setting `save_compressed: true` in your configuration enables saving models in a compressed format, which:
- Reduces disk space usage by approximately 40%
- Maintains compatibility with vLLM for accelerated inference
- Maintains compatibility with llmcompressor for further optimization (example: quantization)
This option is highly recommended when working with sparse models to maximize the benefits of model compression.
### Example Config
See [`examples/llama-3/sparse-finetuning.yaml`](examples/llama-3/sparse-finetuning.yaml) for a complete example.
---
## Inference with vLLM
After fine-tuning your sparse model, you can leverage vLLM for efficient inference.
You can also use LLMCompressor to apply additional quantization to your fine-tuned
sparse model before inference for even greater performance benefits.:
```python
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM("path/to/your/sparse/model")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
For more details on vLLM's capabilities and advanced configuration options, see the [official vLLM documentation](https://docs.vllm.ai/).
## Learn More
For details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:
[https://github.com/vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor)

View File

@@ -0,0 +1,5 @@
"""Integration entry point for the LLMCompressor plugin."""
from .plugin import LLMCompressorPlugin
__all__ = ["LLMCompressorPlugin"]

View File

@@ -0,0 +1,40 @@
"""
LLMCompressor and Sparse Finetuning config models.
"""
from typing import Any
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class CompressionArgs(BaseModel):
"""Sparse Finetuning config for LLMCompressor."""
# Typing for recipe is set to Any due to:
# https://github.com/vllm-project/llm-compressor/issues/1319
recipe: Annotated[
Any,
Field(
description="The recipe containing the compression algorithms and hyperparameters to apply."
),
]
save_compressed: Annotated[
bool,
Field(
default=False,
description="Whether to save the compressed model after training.",
),
]
class LLMCompressorArgs(BaseModel):
"""LLMCompressor configuration BaseModel."""
llmcompressor: Annotated[
CompressionArgs,
Field(
description="Arguments enabling compression pathways through the LLM Compressor plugins"
),
]

View File

@@ -0,0 +1,171 @@
"""
Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
from llmcompressor import active_session, create_session
from llmcompressor.core import callbacks as session_callbacks
from llmcompressor.recipe import Recipe
from torch.nn import Module
from transformers.trainer import Trainer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
class LLMCompressorCallbackHandler(TrainerCallback):
"""
Trainer callback for Sparse Finetuning.
Maintains sparsity patterns during training by applying masks after optimization steps,
ensuring zero-weight updates are canceled out.
"""
def __init__(self, trainer: Trainer, recipe: Any):
"""
Initialize the Sparse Finetuning callback handler.
Args:
trainer (Trainer): Huggingface Trainer instance.
recipe (Recipe | dict): Sparse finetuning recipe to apply.
"""
super().__init__()
self.trainer = trainer
self.recipe = (
Recipe.model_validate(recipe) if not isinstance(recipe, Recipe) else recipe
)
self.original_compute_loss = trainer.compute_loss
self.trainer.compute_loss = compute_loss_wrapper(self.trainer.compute_loss)
create_session()
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of training. Initializes the compression session.
Args:
args (TrainingArguments): Training arguments.
state (TrainerState): Trainer state.
control (TrainerControl): Trainer control.
"""
super().on_train_begin(args, state, control, **kwargs)
self.trainer.accelerator.wait_for_everyone()
active_session().initialize(
model=self.trainer.model,
optimizer=self.trainer.optimizer,
start=state.epoch,
recipe=self.recipe,
)
self.trainer.accelerator.wait_for_everyone()
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the beginning of a training step. Triggers batch_start callback.
"""
super().on_step_begin(args, state, control, **kwargs)
session_callbacks.batch_start()
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of a training step. Triggers optimizer and batch_end callbacks.
"""
super().on_step_end(args, state, control, **kwargs)
session_callbacks.optim_pre_step()
session_callbacks.optim_post_step()
session_callbacks.batch_end()
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> None:
"""
Called at the end of training. Finalizes the compression session.
"""
super().on_train_end(args, state, control, **kwargs)
active_session().finalize()
self.trainer.compute_loss_func = self.original_compute_loss
class LLMCompressorPlugin(BasePlugin):
"""
Sparse Finetuning plugin for Axolotl integration.
"""
def get_input_args(self) -> str:
"""
Returns the path to the plugin's argument definition.
Returns:
str: Dotted path to the LLMCompressorArgs class.
"""
return "axolotl.integrations.llm_compressor.args.LLMCompressorArgs"
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds Sparse Finetuning callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
LOG.info("Adding Sparse Finetuning callback to the trainer")
callback = LLMCompressorCallbackHandler(
trainer=trainer,
recipe=cfg.llmcompressor.recipe,
)
return [callback]
def compute_loss_wrapper(
compute_loss_func: Callable[Concatenate[Module, P], R],
) -> Callable[Concatenate[Module, P], R]:
"""
Wraps the loss computation function to trigger the loss_calculated callback.
Args:
compute_loss_func (Callable): Original loss computation function.
Returns:
Callable: Wrapped function that also invokes the loss_calculated callback.
"""
@wraps(compute_loss_func)
def compute_and_notify(model: Module, *args: P.args, **kwargs: P.kwargs) -> R:
loss = compute_loss_func(model, *args, **kwargs)
if active_session().lifecycle.initialized_ and model.training:
session_callbacks.loss_calculated(loss=loss)
return loss
return compute_and_notify

View File

@@ -0,0 +1,40 @@
"""Utilities for llmcompressor integration with axolotl."""
from typing import Union
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
)
from transformers import PreTrainedModel, Trainer
def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
Synchronize processes, apply compression hooks, and save the model.
Args:
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
# Only the main process writes the files
if not trainer.accelerator.is_main_process:
return
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -31,6 +31,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2",
"deepseek_v2",
"deepseek_v3",
"glm",
"glm4",
]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -271,6 +271,19 @@ def save_trained_model(
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
elif hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
from axolotl.integrations.llm_compressor.utils import (
save_compressed_model,
)
save_compressed_model(
model=model,
output_dir=cfg.output_dir,
trainer=trainer,
safe_serialization=safe_serialization,
save_compressed=cfg.llmcompressor.save_compressed,
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
@@ -279,6 +292,7 @@ def save_trained_model(
trainer.model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

View File

@@ -236,18 +236,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg):
"""

View File

@@ -36,7 +36,6 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
@@ -140,6 +139,22 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
hasattr(model_config, "quantization_config")
and model_config.quantization_config
)
# Detect compressed-tensors config
is_compressed_tensors_config = (
quant_config_exists
and model_config.quantization_config.get("quant_method") == "compressed-tensors"
)
if is_compressed_tensors_config:
if model_config.quantization_config.get("config_groups"):
LOG.warning(
"Found `config_groups` in a compressed-tensors config. "
"QAT integration with llmcompressor is not tested."
)
# Skip further quant checks for compressed-tensors
return
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
@@ -834,13 +849,6 @@ class ModelLoader:
del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -862,21 +870,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
):
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
}
quant_config_class = quant_config_class_dict[
self.model_config.quantization_config["quant_method"]
]
self.model_kwargs["quantization_config"] = quant_config_class(
**self.model_config.quantization_config
)
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
@@ -894,8 +902,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
@@ -911,13 +919,6 @@ class ModelLoader:
**bnb_config,
)
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
@@ -1051,12 +1052,6 @@ class ModelLoader:
config=self.model_config,
)
else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1211,7 +1206,7 @@ class ModelLoader:
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1481,16 +1476,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
from hqq.core.peft import HQQLinearLoRA
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -1,8 +1,8 @@
"""Pydantic models for PEFT-related configuration"""
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Any
from axolotl.utils.schemas.quant import QuantizationConfig
from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel):
@@ -23,11 +23,8 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
quantization: QuantizationConfig | None = None
load_in_4bit: bool | None = None # for internal use
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
adapter: str | None = None
lora_model_dir: str | None = None
@@ -53,6 +50,8 @@ class LoraConfig(BaseModel):
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -75,11 +74,11 @@ class LoraConfig(BaseModel):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("quantization"))
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
):
raise ValueError(
"Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization."
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@@ -87,26 +86,25 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
if self.quantization.bits == 8 or self.load_in_8bit:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.quantization.bits == 4 or self.load_in_4bit:
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.quantization:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@@ -123,24 +121,6 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -1,93 +0,0 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs

View File

@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
attention_type: True,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -0,0 +1,104 @@
"""
E2E smoke tests for LLMCompressorPlugin integration
"""
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
MODELS = [
"nm-testing/llama2.c-stories42M-pruned2.4-compressed",
"nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
]
@pytest.mark.parametrize(
"base_model", MODELS, ids=["no-checkpoint-recipe", "with-checkpoint-recipe"]
)
@pytest.mark.parametrize(
"save_compressed", [True, False], ids=["save_compressed", "save_uncompressed"]
)
class TestLLMCompressorIntegration:
"""
e2e tests for axolotl.integrations.llm_compressor.LLMCompressorPlugin
"""
@require_torch_2_4_1
def test_llmcompressor_plugin(
self, temp_dir, base_model: str, save_compressed: bool
):
# core cfg
cfg = DictDefault(
{
"base_model": base_model,
"plugins": ["axolotl.integrations.llm_compressor.LLMCompressorPlugin"],
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"llmcompressor": {
"recipe": {
"finetuning_stage": {
"finetuning_modifiers": {
"ConstantPruningModifier": {
"targets": [
"re:.*q_proj.weight",
"re:.*k_proj.weight",
"re:.*v_proj.weight",
"re:.*o_proj.weight",
"re:.*gate_proj.weight",
"re:.*up_proj.weight",
"re:.*down_proj.weight",
],
"start": 0,
},
},
},
},
"save_compressed": save_compressed,
},
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
_check_llmcompressor_model_outputs(temp_dir, save_compressed)
def _check_llmcompressor_model_outputs(temp_dir, save_compressed):
# recipe.yaml should exist
assert (Path(temp_dir) / "recipe.yaml").exists()
# sparsity config exists if save_compressed
if save_compressed:
from compressed_tensors import ModelCompressor
from compressed_tensors.config import Sparse24BitMaskConfig
compressor = ModelCompressor.from_pretrained(temp_dir)
assert compressor is not None
assert isinstance(compressor.sparsity_config, Sparse24BitMaskConfig)

View File

@@ -30,10 +30,8 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
@@ -101,10 +99,8 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",

View File

@@ -171,10 +171,7 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -252,10 +249,7 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
@@ -554,10 +548,7 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -657,10 +648,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}
@@ -734,10 +722,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}
@@ -811,10 +796,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}

View File

@@ -28,10 +28,7 @@ class TestMultiGPUQwen2:
cfg = DictDefault(
{
"base_model": base_model,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -32,10 +32,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
@@ -66,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -89,9 +90,6 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
@@ -123,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -33,10 +33,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
@@ -63,6 +60,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True,
"flash_attention": True,
"pad_to_sequence_len": True,
"load_in_8bit": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
@@ -68,6 +68,7 @@ class TestResumeLlama:
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -34,10 +34,7 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,

View File

@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -35,10 +35,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,
@@ -94,10 +91,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,

View File

@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"load_in_8bit": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"seed": 42,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -1,141 +0,0 @@
"""
E2E tests for training with quantized model
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestHQQ(unittest.TestCase):
"""
Test cases for training of HQQ-quantized llama models"""
@with_temp_dir
def test_hqq_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 8,
"group_size": 64,
}
],
"adapter": "lora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@with_temp_dir
def test_hqq_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 4,
"group_size": 64,
}
],
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -74,11 +74,7 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": {
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
@@ -97,10 +93,7 @@ class TestValidation(BaseValidation):
"deepspeed": "",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
@@ -114,10 +107,7 @@ class TestValidation(BaseValidation):
"deepspeed": None,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
@@ -316,10 +306,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
}
)
| base_cfg
@@ -331,9 +318,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"backend": "gptq",
},
"gptq": True,
}
)
| base_cfg
@@ -345,24 +330,19 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"bits": None,
},
"load_in_4bit": False,
}
)
| base_cfg
)
with pytest.raises(ValueError, match=r".*bits <= 4*"):
with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg)
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
)
| base_cfg
@@ -384,10 +364,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
}
)
| base_cfg
@@ -399,10 +376,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"backend": "gptq",
"bits": 4,
},
"gptq": True,
}
)
| base_cfg
@@ -414,9 +388,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"quantization": {
"bits": 4,
},
"load_in_4bit": True,
}
)
| base_cfg
@@ -1004,9 +976,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"quantization": {
"bits": None,
},
"load_in_4bit": True,
}
)
| minimal_cfg
@@ -1014,16 +984,29 @@ class TestValidation(BaseValidation):
with pytest.raises(
ValueError,
match=r"Quantization is not supported without setting an adapter.*",
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"quantization": {
"bits": 4,
},
"load_in_8bit": True,
}
)
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora",
}
)
@@ -1035,9 +1018,7 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"quantization": {
"bits": 8,
},
"load_in_8bit": True,
"adapter": "lora",
}
)

View File

@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from datasets import Dataset
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1,
}
)
self.cfg_1 = validate_config(self.cfg_1)
normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")

View File

@@ -21,10 +21,8 @@ class TestModelsUtils:
"base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
"load_in_4bit": False,
"adapter": "lora",
"flash_attention": False,
"sample_packing": True,