Compare commits

..

43 Commits

Author SHA1 Message Date
Sunny Liu
0179021780 fix attribute error 2025-04-21 22:29:24 -04:00
Sunny Liu
c4910da015 update more tests + better hqq validation 2025-04-21 22:17:08 -04:00
Sunny Liu
db7e92f6a6 check if self.cfg.quantization exists when directly setting load_in_4bit 2025-04-21 21:42:23 -04:00
Sunny Liu
136b37e4d4 restore support for legacy cfg.load_in_xbit 2025-04-21 21:32:01 -04:00
Sunny Liu
92644513c4 update relora 2025-04-21 21:22:44 -04:00
Sunny Liu
266ef3f479 skip set_quant_config if quantization not given 2025-04-21 17:17:41 -04:00
Sunny Liu
fcef8c95fe skip set_quant_config if quantization not given 2025-04-21 17:17:20 -04:00
Sunny Liu
136407c556 update multigpu/test_qwen2 2025-04-21 17:04:17 -04:00
Sunny Liu
3251b3235f update test_mixtral 2025-04-21 17:01:07 -04:00
Sunny Liu
1aa9f7d952 update multigpu/test_eval, multigpu/test_llama 2025-04-21 16:49:08 -04:00
Sunny Liu
a20e753321 update test_falcon_samplepack 2025-04-21 16:29:49 -04:00
Sunny Liu
cb121ab91b update test_mixtral [skip e2e] 2025-04-21 16:27:26 -04:00
Sunny Liu
b59640a4c7 amend model loading for hqq + fix hqq version 2025-04-21 15:53:43 -04:00
Sunny Liu
f0a189131b amend model loading for hqq + fix hqq version 2025-04-21 15:53:29 -04:00
Sunny Liu
c8fb5baad6 amend unittests pt2 2025-04-21 13:28:52 -04:00
Sunny Liu
9be971d47c update test_models.py to conform to new quantization config 2025-04-21 11:34:37 -04:00
Sunny Liu
ffd4ef1ece nit 2025-04-21 11:28:59 -04:00
Sunny Liu
320aff1867 update config doc 2025-04-21 10:59:04 -04:00
Sunny Liu
ac24eba2ac include HQQLinear in find target_linear 2025-04-21 10:36:39 -04:00
Sunny Liu
8a5ad8aee3 typo 2025-04-21 10:36:39 -04:00
Sunny Liu
843b50fdaa rigorous qlora validation 2025-04-21 10:36:39 -04:00
Sunny Liu
098ffcc5a2 removed redundant hqq config validation 2025-04-21 10:36:39 -04:00
Sunny Liu
ba8e29c841 quantization config refactoring - better integration 2025-04-21 10:36:39 -04:00
Sunny Liu
143b2e082c nit [skip e2e] 2025-04-21 10:36:39 -04:00
Sunny Liu
aba484de97 WIP quant config refactor 2025-04-21 10:36:39 -04:00
Sunny Liu
f6f5f89c6d fix more typo 2025-04-21 10:36:39 -04:00
Sunny Liu
8926fe9981 lax config requirement - qlora + hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
987c5217a0 fix typos 2025-04-21 10:36:39 -04:00
Sunny Liu
feaef03cb9 didn't realise model_config.quantization_config is just a regular dict 2025-04-21 10:36:39 -04:00
Sunny Liu
ba5d917845 add e2e test for hqq training 2025-04-21 10:36:39 -04:00
Sunny Liu
0e9b060b4d add doc + requirement for hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
0c40d12a18 more comprehensive hqq config options 2025-04-21 10:36:39 -04:00
Sunny Liu
f55b3c805b hqq_nbits triggers prepare_model_for_kbit_training 2025-04-21 10:36:39 -04:00
Sunny Liu
a64601f957 fix wrong variable name 2025-04-21 10:36:39 -04:00
Sunny Liu
eb7bc70b99 fix dumb mistake 2025-04-21 10:36:39 -04:00
Sunny Liu
db6c76b147 forgot to return data in check 2025-04-21 10:36:39 -04:00
Sunny Liu
99730ce40a hqq integration 2025-04-21 10:36:39 -04:00
Wing Lian
7651550850 make sure to download fixtures for kd test (#2541)
* make sure to download fixtures for kd test

* use same alpaca dataset
2025-04-21 10:31:50 -04:00
Wing Lian
341e95aac9 prevent rate limiting to hf when using dispatch batches (#2536) [skip ci] 2025-04-21 10:31:35 -04:00
Catgat
b882dfb63f Fixed Rex Scheduler Warm Up (#2535) [skip ci]
* Fixed Rex Scheduler Warm Up

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-21 10:30:55 -04:00
Wing Lian
b640db1dbc don't run multigpu tests twice, run SP in separate test (#2542)
* don't run multigpu tests twice, run SP in separate test

* fix multiline
2025-04-21 10:24:13 -04:00
Chiwan Park
4ce469d32e fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
2025-04-18 09:57:40 -07:00
Wing Lian
60a8f0958d zero val fix for beta (#2538) 2025-04-17 17:27:19 -07:00
40 changed files with 597 additions and 194 deletions

View File

@@ -1,13 +1,10 @@
#!/bin/bash
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \
--cov-report=xml:multigpu-coverage.xml
@@ -17,6 +14,11 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -55,20 +55,46 @@ overrides_of_model_config:
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config paramters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq:
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
load_in_8bit:
# Use bitsandbytes 4 bit
load_in_4bit:

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.6
liger-kernel==0.5.8
# END section
packaging==23.2
@@ -22,6 +22,7 @@ hf_xet==1.0.0
optimum==1.16.2
hf_transfer
hqq==0.2.5
sentencepiece
gradio==5.23.3

View File

@@ -1040,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- gemma3
- granite
- jamba
- llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,

View File

@@ -236,6 +236,18 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg):
"""

View File

@@ -3,6 +3,7 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)

View File

@@ -36,6 +36,7 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
@@ -833,6 +834,13 @@ class ModelLoader:
del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -854,21 +862,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
):
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
}
quant_config_class = quant_config_class_dict[
self.model_config.quantization_config["quant_method"]
]
self.model_kwargs["quantization_config"] = quant_config_class(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
@@ -886,8 +894,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs)
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
@@ -903,6 +911,13 @@ class ModelLoader:
**bnb_config,
)
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
@@ -1036,6 +1051,12 @@ class ModelLoader:
config=self.model_config,
)
else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1190,7 +1211,7 @@ class ModelLoader:
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1460,7 +1481,16 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
from hqq.core.peft import HQQLinearLoRA
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
lora_module_names = set()
for name, module in model.named_modules():
if (

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
self.last_step = max(last_step - 1, 0)
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:

View File

@@ -660,6 +660,7 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -1,9 +1,9 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.quant import QuantizationConfig
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
quantization: QuantizationConfig | None = None
load_in_4bit: bool | None = None # for internal use
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
adapter: str | None = None
lora_model_dir: str | None = None
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
and (data.get("quantization"))
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
"Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization."
)
return data
@@ -86,25 +87,26 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.load_in_4bit:
if self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.quantization:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@@ -121,6 +123,24 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -0,0 +1,93 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs

View File

@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
@@ -496,12 +522,6 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(scope="session", autouse=True)
def download_tiny_llama_7m_model():
# download the model
return snapshot_download_w_retry("axolotl-ai-internal/llama-7m", repo_type="model")
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

View File

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ...utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"
@@ -93,7 +93,7 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
)
@pytest.mark.parametrize(

View File

@@ -0,0 +1,2 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,8 +49,9 @@ class TestPackedFlex:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
@@ -89,5 +90,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -30,8 +30,10 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
@@ -99,8 +101,10 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",

View File

@@ -96,5 +96,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
)

View File

@@ -43,7 +43,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
@@ -94,7 +94,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -105,7 +105,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": True,
"eval_sample_packing": False,
@@ -159,19 +159,22 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -244,12 +247,15 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
@@ -326,7 +332,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"val_set_size": 0.01,
"special_tokens": {
@@ -385,7 +391,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -396,7 +402,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
@@ -457,7 +463,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_2_6_0
@@ -475,7 +481,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
@@ -538,7 +544,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -548,7 +554,10 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -618,7 +627,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -648,13 +657,16 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
@@ -702,7 +714,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -722,13 +734,16 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
@@ -776,7 +791,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -796,13 +811,16 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
else:
adapter = {}
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
@@ -850,7 +868,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.skip(
@@ -860,7 +878,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-internal/llama-7m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"fix_untrained_tokens": True,
"sequence_len": 512,
"val_set_size": 0.0,
@@ -917,5 +935,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
)

View File

@@ -28,7 +28,10 @@ class TestMultiGPUQwen2:
cfg = DictDefault(
{
"base_model": base_model,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,

View File

@@ -80,7 +80,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_lt_2_6_0
@@ -138,5 +138,5 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)

View File

@@ -86,5 +86,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
)

View File

@@ -32,7 +32,10 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,

View File

@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,

View File

@@ -33,7 +33,10 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,

View File

@@ -46,8 +46,9 @@ class TestResumeLlama:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 2,

View File

@@ -80,7 +80,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -130,7 +130,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
@@ -185,5 +185,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
@@ -69,5 +70,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -34,7 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,

View File

@@ -84,5 +84,5 @@ class TestPretrainLlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
"Train Loss is too high",
)

View File

@@ -35,7 +35,10 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True,
"sequence_len": 1024,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,
@@ -91,7 +94,10 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 1024,
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
"lora_r": 4,
"lora_alpha": 8,

View File

@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
@@ -68,5 +69,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -0,0 +1,141 @@
"""
E2E tests for training with quantized model
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestHQQ(unittest.TestCase):
"""
Test cases for training of HQQ-quantized llama models"""
@with_temp_dir
def test_hqq_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 8,
"group_size": 64,
}
],
"adapter": "lora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@with_temp_dir
def test_hqq_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 4,
"group_size": 64,
}
],
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -73,6 +73,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train loss (%s) is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"adapter": "qlora",
}
| minimal_cfg
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
"deepspeed": "",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
}
| minimal_cfg
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
"deepspeed": None,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora",
}
| minimal_cfg
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
}
)
| base_cfg
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
"quantization": {
"backend": "gptq",
},
}
)
| base_cfg
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": False,
"quantization": {
"bits": None,
},
}
)
| base_cfg
)
with pytest.raises(ValueError, match=r".*4bit.*"):
with pytest.raises(ValueError, match=r".*bits <= 4*"):
validate_config(cfg)
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
}
)
| base_cfg
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
}
)
| base_cfg
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
"quantization": {
"backend": "gptq",
"bits": 4,
},
}
)
| base_cfg
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
"quantization": {
"bits": 4,
},
}
)
| base_cfg
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"load_in_4bit": True,
"quantization": {
"bits": None,
},
}
)
| minimal_cfg
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
match=r"Quantization is not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_8bit": True,
}
)
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"quantization": {
"bits": 4,
},
"adapter": "qlora",
}
)
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
cfg = (
DictDefault(
{
"load_in_8bit": True,
"quantization": {
"bits": 8,
},
"adapter": "lora",
}
)

View File

@@ -8,7 +8,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
from axolotl.utils.callbacks.perplexity import Perplexity
MODEL_NAME = "axolotl-ai-internal/llama-7m"
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"
@fixture()
@@ -36,7 +36,7 @@ One day, a little fish named Fin was swimming near the shore. He saw a big crab
"""
result = metric.compute(model, [sample_text])
ppl = result["score"]
assert round(ppl, 2) == 75.14
assert round(ppl, 2) == 7.41
def test_perplexity_short(model, metric):
@@ -44,4 +44,4 @@ def test_perplexity_short(model, metric):
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
result = metric.compute(model, [sample_text])
ppl = result["score"]
assert round(ppl, 2) == 70.54
assert round(ppl, 2) == 10.33

View File

@@ -21,8 +21,10 @@ class TestModelsUtils:
"base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"load_in_8bit": True,
"load_in_4bit": False,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"adapter": "lora",
"flash_attention": False,
"sample_packing": True,