Compare commits

...

43 Commits

Author SHA1 Message Date
Sunny Liu
0179021780 fix attribute error 2025-04-21 22:29:24 -04:00
Sunny Liu
c4910da015 update more tests + better hqq validation 2025-04-21 22:17:08 -04:00
Sunny Liu
db7e92f6a6 check if self.cfg.quantization exists when directly setting load_in_4bit 2025-04-21 21:42:23 -04:00
Sunny Liu
136b37e4d4 restore support for legacy cfg.load_in_xbit 2025-04-21 21:32:01 -04:00
Sunny Liu
92644513c4 update relora 2025-04-21 21:22:44 -04:00
Sunny Liu
266ef3f479 skip set_quant_config if quantization not given 2025-04-21 17:17:41 -04:00
Sunny Liu
fcef8c95fe skip set_quant_config if quantization not given 2025-04-21 17:17:20 -04:00
Sunny Liu
136407c556 update multigpu/test_qwen2 2025-04-21 17:04:17 -04:00
Sunny Liu
3251b3235f update test_mixtral 2025-04-21 17:01:07 -04:00
Sunny Liu
1aa9f7d952 update multigpu/test_eval, multigpu/test_llama 2025-04-21 16:49:08 -04:00
Sunny Liu
a20e753321 update test_falcon_samplepack 2025-04-21 16:29:49 -04:00
Sunny Liu
cb121ab91b update test_mixtral [skip e2e] 2025-04-21 16:27:26 -04:00
Sunny Liu
b59640a4c7 amend model loading for hqq + fix hqq version 2025-04-21 15:53:43 -04:00
Sunny Liu
f0a189131b amend model loading for hqq + fix hqq version 2025-04-21 15:53:29 -04:00
Sunny Liu
c8fb5baad6 amend unittests pt2 2025-04-21 13:28:52 -04:00
Sunny Liu
9be971d47c update test_models.py to conform to new quantization config 2025-04-21 11:34:37 -04:00
Sunny Liu
ffd4ef1ece nit 2025-04-21 11:28:59 -04:00
Sunny Liu
320aff1867 update config doc 2025-04-21 10:59:04 -04:00
Sunny Liu
ac24eba2ac include HQQLinear in find target_linear 2025-04-21 10:36:39 -04:00
Sunny Liu
8a5ad8aee3 typo 2025-04-21 10:36:39 -04:00
Sunny Liu
843b50fdaa rigorous qlora validation 2025-04-21 10:36:39 -04:00
Sunny Liu
098ffcc5a2 removed redundant hqq config validation 2025-04-21 10:36:39 -04:00
Sunny Liu
ba8e29c841 quantization config refactoring - better integration 2025-04-21 10:36:39 -04:00
Sunny Liu
143b2e082c nit [skip e2e] 2025-04-21 10:36:39 -04:00
Sunny Liu
aba484de97 WIP quant config refactor 2025-04-21 10:36:39 -04:00
Sunny Liu
f6f5f89c6d fix more typo 2025-04-21 10:36:39 -04:00
Sunny Liu
8926fe9981 lax config requirement - qlora + hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
987c5217a0 fix typos 2025-04-21 10:36:39 -04:00
Sunny Liu
feaef03cb9 didn't realise model_config.quantization_config is just a regular dict 2025-04-21 10:36:39 -04:00
Sunny Liu
ba5d917845 add e2e test for hqq training 2025-04-21 10:36:39 -04:00
Sunny Liu
0e9b060b4d add doc + requirement for hqq 2025-04-21 10:36:39 -04:00
Sunny Liu
0c40d12a18 more comprehensive hqq config options 2025-04-21 10:36:39 -04:00
Sunny Liu
f55b3c805b hqq_nbits triggers prepare_model_for_kbit_training 2025-04-21 10:36:39 -04:00
Sunny Liu
a64601f957 fix wrong variable name 2025-04-21 10:36:39 -04:00
Sunny Liu
eb7bc70b99 fix dumb mistake 2025-04-21 10:36:39 -04:00
Sunny Liu
db6c76b147 forgot to return data in check 2025-04-21 10:36:39 -04:00
Sunny Liu
99730ce40a hqq integration 2025-04-21 10:36:39 -04:00
Wing Lian
7651550850 make sure to download fixtures for kd test (#2541)
* make sure to download fixtures for kd test

* use same alpaca dataset
2025-04-21 10:31:50 -04:00
Wing Lian
341e95aac9 prevent rate limiting to hf when using dispatch batches (#2536) [skip ci] 2025-04-21 10:31:35 -04:00
Catgat
b882dfb63f Fixed Rex Scheduler Warm Up (#2535) [skip ci]
* Fixed Rex Scheduler Warm Up

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-21 10:30:55 -04:00
Wing Lian
b640db1dbc don't run multigpu tests twice, run SP in separate test (#2542)
* don't run multigpu tests twice, run SP in separate test

* fix multiline
2025-04-21 10:24:13 -04:00
Chiwan Park
4ce469d32e fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
2025-04-18 09:57:40 -07:00
Wing Lian
60a8f0958d zero val fix for beta (#2538) 2025-04-17 17:27:19 -07:00
32 changed files with 558 additions and 149 deletions

View File

@@ -1,13 +1,10 @@
 #!/bin/bash
 set -e
-# only run one test at a time so as not to OOM the GPU
-pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
-pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
 # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
 pytest -v -n2 \
-    --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
+    --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
+    --ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
     /workspace/axolotl/tests/e2e/multigpu/ \
     --cov=axolotl \
     --cov-report=xml:multigpu-coverage.xml
@@ -17,6 +14,11 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
     --cov-append \
     --cov-report=xml:multigpu-coverage.xml
+pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
+    --cov=axolotl \
+    --cov-append \
+    --cov-report=xml:multigpu-coverage.xml
 # Upload coverage to Codecov
 if [ -f multigpu-coverage.xml ]; then
     codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -55,20 +55,46 @@ overrides_of_model_config:
 overrides_of_model_kwargs:
   # use_cache: False
-# optional overrides to the bnb 4bit quantization configuration
-# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
-bnb_config_kwargs:
-  # These are default values
-  llm_int8_has_fp16_weight: false
-  bnb_4bit_quant_type: nf4
-  bnb_4bit_use_double_quant: true
+# Quantization configuration.
+quantization:
+  backend: bnb | hqq | gptq
+  bits: 8
+  # optional overrides to the bnb 4bit quantization configuration
+  # https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
+  bnb_config_kwargs:
+    # These are default values
+    llm_int8_has_fp16_weight: false
+    bnb_4bit_quant_type: nf4
+    bnb_4bit_use_double_quant: true
+  # If using the hqq backend, additional config parameters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
+  hqq_config:
+    # pick one of the following, depending on whether you want to uniformly quantize the whole model or
+    # apply different quantization settings to specific layers in the model:
+    # to uniformly quantize the whole model:
+    group_size: 64
+    # to invoke dynamic_config in order to apply different quantization settings to specific layers:
+    - nbits: 4
+      group_size: 64
+      target_modules:
+        - self_attn.k_proj
+        - self_attn.v_proj
+        - self_attn.o_proj
+    - nbits: 3
+      group_size: 32
+      target_modules:
+        - mlp.gate_proj
+        - mlp.up_proj
+        - mlp.down_proj
+
+# (Internal Use Only)
 # Whether you are training a 4-bit GPTQ quantized model
-gptq: true
+gptq:
 # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
-load_in_8bit: true
+load_in_8bit:
 # Use bitsandbytes 4 bit
 load_in_4bit:
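
For reference, the same block in the Python DictDefault form used by the tests later in this compare (the model name and values are illustrative, not defaults):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # illustrative
        "adapter": "qlora",
        "quantization": {
            "backend": "bnb",
            "bits": 4,
            # bnb_config_kwargs / hqq_config nest here, mirroring the YAML above
        },
    }
)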

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
-liger-kernel==0.5.6
+liger-kernel==0.5.8
 # END section

 packaging==23.2
@@ -22,6 +22,7 @@ hf_xet==1.0.0
 optimum==1.16.2
 hf_transfer
+hqq==0.2.5
 sentencepiece
 gradio==5.23.3

View File

@@ -1040,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.dataset_processes:
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-        if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
-            training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
-        if self.cfg.orpo_alpha:
+        if self.cfg.trl and self.cfg.trl.beta is not None:
+            training_args_kwargs["beta"] = self.cfg.trl.beta
+        elif self.cfg.rl_beta is not None:
+            training_args_kwargs["beta"] = self.cfg.rl_beta
+        elif self.cfg.orpo_alpha is not None:
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
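
A minimal illustration of the truthiness bug this hunk fixes (the "zero val fix for beta" commit): beta = 0.0 is a legitimate value but is falsy, so the old combined check silently dropped it, while `is not None` keeps it.

beta = 0.0
kwargs = {}
if beta:                 # old-style check: never taken for 0.0
    kwargs["beta"] = beta
if beta is not None:     # new-style check: 0.0 is preserved
    kwargs["beta"] = beta
assert kwargs == {"beta": 0.0}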

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
 - deepseek_v2
 - gemma
 - gemma2
-- gemma3 (partial support, no support for FLCE yet)
+- gemma3
 - granite
 - jamba
 - llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
 import inspect
 import logging
 import sys
-from functools import partial

 from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
         )
         from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
         from liger_kernel.transformers.functional import liger_cross_entropy
-        from liger_kernel.transformers.geglu import LigerGEGLUMLP
         from liger_kernel.transformers.layer_norm import LigerLayerNorm
         from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
         from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-        elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
-            from transformers.models.gemma3 import modeling_gemma3
-
-            if cfg.liger_rope:
-                modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-
-                def _liger_rms_norm_wrapper(dim, **kwargs):
-                    "Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
-                    return LigerRMSNorm(hidden_size=dim, **kwargs)
-
-                modeling_gemma3.Gemma3RMSNorm = partial(
-                    _liger_rms_norm_wrapper,
-                    offset=1.0,
-                    casting_mode="gemma",
-                    init_fn="zeros",
-                    in_place=False,
-                )
-            if cfg.liger_glu_activation:
-                modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
-            if cfg.liger_layer_norm:
-                modeling_gemma3.nn.LayerNorm = LigerLayerNorm
-            if cfg.liger_cross_entropy:
-                from transformers.loss.loss_utils import nn
-
-                nn.functional.cross_entropy = liger_cross_entropy
-            if cfg.liger_fused_linear_cross_entropy:
-                raise NotImplementedError(
-                    "Fused linear cross entropy is not yet supported for Gemma3."
-                )
         elif cfg.model_config_type == "llama4":
             from axolotl.integrations.liger.models.llama4 import (
                 apply_liger_kernel_to_llama4,

View File

@@ -236,6 +236,18 @@ def normalize_config(cfg):
         log_gpu_memory_usage(LOG, "baseline", cfg.device)

+    if cfg.quantization:
+        if cfg.quantization.backend in ["bnb"]:
+            if cfg.quantization.bits == 8:
+                cfg.load_in_8bit = True
+            elif cfg.quantization.bits == 4:
+                cfg.load_in_4bit = True
+        if cfg.quantization.backend == "gptq":
+            cfg.gptq = True
+        elif cfg.quantization.backend == "hqq":
+            cfg.hqq = True
+
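
A standalone sketch of the mapping introduced above (not the axolotl function itself): the unified quantization block is translated back into the legacy flags that the rest of the codebase still reads.

def legacy_flags(quantization: dict) -> dict:
    # mirrors the normalize_config block shown in the hunk above
    flags = {}
    if quantization.get("backend") == "bnb":
        if quantization.get("bits") == 8:
            flags["load_in_8bit"] = True
        elif quantization.get("bits") == 4:
            flags["load_in_4bit"] = True
    elif quantization.get("backend") == "gptq":
        flags["gptq"] = True
    elif quantization.get("backend") == "hqq":
        flags["hqq"] = True
    return flags

assert legacy_flags({"backend": "bnb", "bits": 4}) == {"load_in_4bit": True}
assert legacy_flags({"backend": "hqq", "bits": 4}) == {"hqq": True}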
def normalize_cfg_datasets(cfg): def normalize_cfg_datasets(cfg):
""" """

View File

@@ -3,6 +3,7 @@
 import functools
 import logging
 import os
+import tempfile
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )

-        iter_ds = load_dataset(
-            path, streaming=True, split=split, name=name, data_files=data_files
-        )
+        # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
+        # other ranks, we just need to present a fake dataset
+        if (
+            cfg.accelerator_config
+            and cfg.accelerator_config.dispatch_batches
+            and not is_local_main_process()
+        ):
+            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
+                f.write("text\n")
+                f.write("lorem ipsum dolor sit amet\n")
+                # rewind the file pointer to the beginning so we can read it again
+                f.seek(0)
+
+                iter_ds = load_dataset(
+                    "csv", data_files=f.name, split="train", streaming=True
+                )
+        else:
+            if is_local_main_process():
+                iter_ds = load_dataset(
+                    path, streaming=True, split=split, name=name, data_files=data_files
+                )
+
         if skip:
             LOG.info(f"Skipping {skip} samples from the dataset")
             iter_ds = iter_ds.skip(skip)
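
A minimal sketch of the placeholder-dataset trick used above: when the main process dispatches batches, non-main ranks only need something with the streaming-dataset interface, so a throwaway one-row CSV is enough (file contents here are illustrative).

import tempfile

from datasets import load_dataset

with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
    f.write("text\n")
    f.write("lorem ipsum dolor sit amet\n")
    f.seek(0)  # rewind so the file can be read back
    placeholder = load_dataset("csv", data_files=f.name, split="train", streaming=True)

print(next(iter(placeholder)))  # {'text': 'lorem ipsum dolor sit amet'}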

View File

@@ -36,6 +36,7 @@ from transformers import (
     BitsAndBytesConfig,
     Gemma3ForConditionalGeneration,
     GPTQConfig,
+    HqqConfig,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
     Mistral3ForConditionalGeneration,
@@ -833,6 +834,13 @@ class ModelLoader:
             del self.model_kwargs["device_map"]

     def set_quantization_config(self) -> None:
+        if (
+            (not self.cfg.quantization)
+            and (not self.cfg.load_in_8bit)
+            and (not self.cfg.load_in_4bit)
+            and not self.cfg.gptq
+        ):
+            return
         self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
         self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -854,21 +862,21 @@
             and hasattr(self.model_config, "quantization_config")
             and self.model_config.quantization_config["quant_method"]
             in ["gptq", "awq", "bitsandbytes"]
+            and not self.cfg.hqq
         ):
-            if self.model_config.quantization_config["quant_method"] == "gptq":
-                self.model_kwargs["quantization_config"] = GPTQConfig(
-                    **self.model_config.quantization_config
-                )
-            elif self.model_config.quantization_config["quant_method"] == "awq":
-                self.model_kwargs["quantization_config"] = AwqConfig(
-                    **self.model_config.quantization_config
-                )
-            elif (
-                self.model_config.quantization_config["quant_method"] == "bitsandbytes"
-            ):
-                self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
-                    **self.model_config.quantization_config
-                )
+            quant_config_class_dict = {
+                "gptq": GPTQConfig,
+                "awq": AwqConfig,
+                "bitsandbytes": BitsAndBytesConfig,
+            }
+
+            quant_config_class = quant_config_class_dict[
+                self.model_config.quantization_config["quant_method"]
+            ]
+            self.model_kwargs["quantization_config"] = quant_config_class(
+                **self.model_config.quantization_config
+            )
         elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
             bnb_config = {
                 "load_in_4bit": True,
@@ -886,8 +894,8 @@
                 # but deepspeed needs this still in bfloat16
                 bnb_config["bnb_4bit_quant_storage"] = torch.float32

-            if self.cfg.bnb_config_kwargs:
-                bnb_config.update(self.cfg.bnb_config_kwargs)
+            if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
+                bnb_config.update(self.cfg.quantization.bnb_config_kwargs)

             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
@@ -903,6 +911,13 @@
                 **bnb_config,
             )

+        if self.cfg.hqq:
+            from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
+
+            self.model_kwargs["quantization_config"] = HqqConfig(
+                **get_hqq_quant_config_kwargs(self.cfg)
+            )
+
         # no longer needed per https://github.com/huggingface/transformers/pull/26610
         if "quantization_config" in self.model_kwargs or self.cfg.gptq:
             self.model_kwargs.pop("load_in_8bit", None)
@@ -1036,6 +1051,12 @@
                     config=self.model_config,
                 )
             else:
+                if self.cfg.hqq and torch.cuda.device_count() < 2:
+                    # for some reason on single gpu, we need to set device_map to auto/cuda
+                    # otherwise you run into tensors on two devices error during training
+                    # Doesn't affect multi-gpu tho
+                    self.model_kwargs["device_map"] = "auto"
+
                 self.model = self.auto_model_loader.from_pretrained(
                     self.base_model,
                     config=self.model_config,
@@ -1190,7 +1211,7 @@
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
-            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
+            and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
         ):
             LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             self.model = prepare_model_for_kbit_training(
@@ -1460,7 +1481,16 @@ def load_llama_adapter(model, cfg):

 def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+    from hqq.core.peft import HQQLinearLoRA
+    from hqq.core.quantize import HQQLinear
+
+    cls = (
+        bnb.nn.Linear4bit,
+        bnb.nn.Linear8bitLt,
+        torch.nn.Linear,
+        HQQLinear,
+        HQQLinearLoRA,
+    )
     lora_module_names = set()
     for name, module in model.named_modules():
         if (
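
The if/elif chain above collapses into a table-driven lookup; a standalone sketch of the same pattern follows (the transformers config classes are the real ones, the chosen quant method and kwargs are illustrative):

from transformers import AwqConfig, BitsAndBytesConfig, GPTQConfig

quant_config_class_dict = {
    "gptq": GPTQConfig,
    "awq": AwqConfig,
    "bitsandbytes": BitsAndBytesConfig,
}

quant_method = "bitsandbytes"  # would normally come from model_config.quantization_config
quant_cls = quant_config_class_dict[quant_method]
hf_quant_config = quant_cls(load_in_4bit=True)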

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
         self.max_lr = max_lr
         self.total_steps = total_steps
         self.num_warmup_steps = num_warmup_steps
-        self.last_step = last_step - 1
+        self.last_step = max(last_step - 1, 0)

         # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
         for group in optimizer.param_groups:
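
A tiny arithmetic check of the clamp above (illustrative values only): for any last_step below 1 the old expression produced a negative starting step, which max() now prevents.

for last_step in (0, 1, 5):
    assert max(last_step - 1, 0) >= 0  # new behaviour never goes negative
assert (0 - 1) == -1                   # old behaviour underflowed when last_step was 0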

View File

@@ -660,6 +660,7 @@ class AxolotlInputConfig(
             data.get("val_set_size") == 0
             and (data.get("eval_steps") or data.get("eval_strategy"))
             and not data.get("test_datasets")
+            and data.get("eval_strategy") != "no"
         ):
             raise ValueError(
                 "eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -1,9 +1,9 @@
 """Pydantic models for PEFT-related configuration"""

-from typing import Any
-
 from pydantic import BaseModel, Field, field_validator, model_validator

+from axolotl.utils.schemas.quant import QuantizationConfig
+

 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
 class LoraConfig(BaseModel):
     """Peft / LoRA configuration subset"""

-    load_in_8bit: bool | None = Field(default=False)
-    load_in_4bit: bool | None = Field(default=False)
+    quantization: QuantizationConfig | None = None
+    load_in_4bit: bool | None = None  # for internal use
+    load_in_8bit: bool | None = None  # for internal use
+    hqq: bool | None = None  # for internal use
+    gptq: bool | None = None  # for internal use

     adapter: str | None = None
     lora_model_dir: str | None = None
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
         },
     )
     lora_on_cpu: bool | None = None
-    gptq: bool | None = None
-    bnb_config_kwargs: dict[str, Any] | None = None

     loraplus_lr_ratio: float | None = Field(
         default=None,
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
         if (
             not data.get("adapter")
             and not data.get("inference")
-            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
+            and (data.get("quantization"))
         ):
             raise ValueError(
-                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
-                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
+                "Quantization is not supported without setting an adapter."
+                "If you want to full finetune, please turn off Quantization."
             )
         return data
@@ -86,25 +87,26 @@ class LoraConfig(BaseModel):
     def validate_qlora(self):
         if self.adapter == "qlora":
             if self.merge_lora:
-                # can't merge qlora if loaded in 8bit or 4bit
-                if self.load_in_8bit:
+                if self.quantization.bits == 8 or self.load_in_8bit:
                     raise ValueError("Can't merge qlora if loaded in 8bit")
-                if self.gptq:
-                    raise ValueError("Can't merge qlora if gptq")
-                if self.load_in_4bit:
+                if self.quantization.backend == "gptq":
+                    raise ValueError("Can't merge qlora if using gptq")
+                if self.quantization.bits == 4 or self.load_in_4bit:
                     raise ValueError("Can't merge qlora if loaded in 4bit")
             else:
-                if self.load_in_8bit:
-                    raise ValueError("Can't load qlora in 8bit")
-                if self.gptq:
-                    raise ValueError("Can't load qlora if gptq")
-                if not self.load_in_4bit:
-                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+                if self.quantization:
+                    if self.quantization.bits == 8 or self.load_in_8bit:
+                        raise ValueError("Can't load qlora in 8bit")
+                    if self.quantization.backend == "gptq":
+                        raise ValueError("Can't load qlora if using gptq")
+                    if not self.quantization.bits == 4 or self.load_in_4bit:
+                        raise ValueError("Require quantization.bits <= 4 for qlora")

         return self

     @field_validator("loraplus_lr_embedding")
@@ -121,6 +123,24 @@
             data["lora_dropout"] = 0.0
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def validate_hqq(cls, data):
+        if (
+            data.get("quantization")
+            and data.get("quantization").get("backend") == "hqq"
+        ):
+            if not data.get("quantization").get("hqq_config"):
+                raise ValueError(
+                    "If using HQQ, must set `hqq_config` under `quantization`"
+                )
+            if data.get("load_in_4bit") or data.get("load_in_8bit"):
+                raise ValueError(
+                    "If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
+                )
+        return data
+

 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""

View File

@@ -0,0 +1,93 @@
+"""
+Takes care of quantization configuration
+"""
+
+from typing import Annotated, Any, Literal
+
+from annotated_types import MinLen
+from pydantic import BaseModel, Field, model_validator
+
+
+class HQQConfig(BaseModel):
+    """HQQ configuration subset"""
+
+    nbits: Literal[8, 4, 3, 2, 1] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
+        },
+    )
+    group_size: int = Field(default=64)
+    target_modules: list[str] | str | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
+        },
+    )
+
+
+class QuantizationConfig(BaseModel):
+    """Overall quantization configuration subset"""
+
+    # We will use this class as the base for future refactoring of all quantization configs
+    backend: Literal["bnb", "hqq", "gptq"] | None = None
+    bits: Literal[8, 4, 3, 2, 1] | None = None
+    bnb_config_kwargs: dict[str, Any] | None = None
+    hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_hqq_config(cls, data):
+        if data.get("backend") == "hqq" and not data.get("hqq_config"):
+            raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
+        if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
+            for hqq_config in data.get("hqq_config"):
+                if hqq_config.get("target_modules") is None:
+                    raise ValueError(
+                        "For list of hqq configs, `target_modules` must be specified for each"
+                    )
+        return data
+
+
+def get_hqq_quant_config_kwargs(cfg):
+    # If no target module is specified, then target the whole model
+    if not isinstance(cfg.quantization.hqq_config, list):
+        cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
+    if (
+        len(cfg.quantization.hqq_config) == 1
+        and cfg.quantization.hqq_config[0].target_modules is None
+    ):
+        nbits = (
+            cfg.quantization.hqq_config[0].nbits
+            if cfg.quantization.hqq_config[0].nbits is not None
+            else cfg.quantization.bits
+        )
+        return {
+            "nbits": nbits,
+            "group_size": cfg.quantization.hqq_config[0].group_size,
+        }
+
+    hqq_quant_config_kwargs = {"dynamic_config": {}}
+    for hqq_config in cfg.quantization.hqq_config:
+        nbits = (
+            hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
+        )
+        target_modules = hqq_config.target_modules
+        if not isinstance(target_modules, list):
+            target_modules = [target_modules]
+        for module in target_modules:
+            hqq_quant_config_kwargs["dynamic_config"][module] = {
+                "nbits": nbits,
+                "group_size": hqq_config.group_size,
+            }
+    return hqq_quant_config_kwargs
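
A hedged usage sketch for get_hqq_quant_config_kwargs, assuming a DictDefault-shaped cfg like the tests in this compare use (module names and values are illustrative): a list of per-layer entries becomes a dynamic_config mapping that is passed straight to transformers.HqqConfig.

from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs

cfg = DictDefault(
    {
        "quantization": {
            "backend": "hqq",
            "bits": 4,
            "hqq_config": [
                {"nbits": 4, "group_size": 64, "target_modules": ["self_attn.q_proj"]},
                {"nbits": 3, "group_size": 32, "target_modules": ["mlp.gate_proj"]},
            ],
        }
    }
)
print(get_hqq_quant_config_kwargs(cfg))
# -> {"dynamic_config": {"self_attn.q_proj": {"nbits": 4, "group_size": 64},
#                        "mlp.gate_proj": {"nbits": 3, "group_size": 32}}}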

View File

@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
     snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")


+@pytest.fixture(scope="session", autouse=True)
+def download_evolkit_kd_sample_dataset():
+    # download the dataset
+    snapshot_download_w_retry(
+        "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_deepseek_model_fixture():
     snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
     )


+@pytest.fixture(scope="session", autouse=True)
+def download_llama33_70b_model_fixture():
+    # download the tokenizer only
+    snapshot_download_w_retry(
+        "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
+        repo_type="model",
+        allow_patterns=["*token*", "config.json"],
+    )
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_llama_1b_model_fixture():
     # download the tokenizer only
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
     )


+@pytest.fixture(scope="session", autouse=True)
+def download_llama32_1b_model_fixture():
+    snapshot_download_w_retry(
+        "osllmai-community/Llama-3.2-1B",
+        repo_type="model",
+    )
+
+
 @pytest.fixture
 @enable_hf_offline
 def tokenizer_huggyllama(

View File

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
 from axolotl.utils.dict import DictDefault

-from ..utils import check_tensorboard
+from ...utils import check_tensorboard

 os.environ["WANDB_DISABLED"] = "true"

View File

@@ -0,0 +1,2 @@
+# Tests under this directory should get run "solo" on their own as they
+# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,8 +49,9 @@ class TestPackedFlex:
                 },
                 "datasets": [
                     {
-                        "path": "vicgalle/alpaca-gpt4",
+                        "path": "tatsu-lab/alpaca",
                         "type": "alpaca",
+                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 1,

View File

@@ -30,8 +30,10 @@ class TestMultiGPUEval:
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "load_in_8bit": False,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "strict": False,
                 "sequence_len": 2048,
                 "adapter": "qlora",
@@ -99,8 +101,10 @@ class TestMultiGPUEval:
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
-                "load_in_8bit": False,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "strict": False,
                 "sequence_len": 2048,
                 "adapter": "qlora",

View File

@@ -171,7 +171,10 @@ class TestMultiGPULlama:
                 "sample_packing": False,
                 "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
-                "load_in_8bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 8,
+                },
                 "adapter": "lora",
                 "lora_r": 8,
                 "lora_alpha": 16,
@@ -249,7 +252,10 @@ class TestMultiGPULlama:
                 "sample_packing": False,
                 "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
                 "lora_r": 8,
                 "lora_alpha": 16,
@@ -548,7 +554,10 @@ class TestMultiGPULlama:
                 "base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
                 "adapter": "qlora",
                 "mean_resizing_embeddings": True,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "lora_r": 8,
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
@@ -648,7 +657,10 @@ class TestMultiGPULlama:
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
             }
         else:
             adapter = {}
@@ -722,7 +734,10 @@ class TestMultiGPULlama:
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
             }
         else:
             adapter = {}
@@ -796,7 +811,10 @@ class TestMultiGPULlama:
                 "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
             }
         else:
             adapter = {}

View File

@@ -28,7 +28,10 @@ class TestMultiGPUQwen2:
         cfg = DictDefault(
             {
                 "base_model": base_model,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "rl": "dpo",
                 "chat_template": "chatml",
                 "sequence_len": 2048,

View File

@@ -32,7 +32,10 @@ class TestFalconPatched(unittest.TestCase):
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
                 "lora_r": 16,
                 "lora_alpha": 32,

View File

@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
                 "sequence_len": 1024,
                 "sample_packing": True,
                 "flash_attention": True,
+                "quantization": {
+                    "backend": "gptq",
+                },
                 "load_in_8bit": True,
                 "adapter": "lora",
                 "gptq": True,

View File

@@ -33,7 +33,10 @@ class TestMixtral(unittest.TestCase):
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
                 "lora_r": 16,
                 "lora_alpha": 32,

View File

@@ -46,8 +46,9 @@ class TestResumeLlama:
                 },
                 "datasets": [
                     {
-                        "path": "vicgalle/alpaca-gpt4",
+                        "path": "tatsu-lab/alpaca",
                         "type": "alpaca",
+                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 2,

View File

@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
                 },
                 "datasets": [
                     {
-                        "path": "vicgalle/alpaca-gpt4",
+                        "path": "tatsu-lab/alpaca",
                         "type": "alpaca",
+                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 1,

View File

@@ -34,7 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
                 "sample_packing": True,
                 "pad_to_sequence_len": True,
                 "flash_attention": True,
-                "load_in_8bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 8,
+                },
                 "adapter": "lora",
                 "lora_r": 8,
                 "lora_alpha": 16,

View File

@@ -35,7 +35,10 @@ class TestMixtral(unittest.TestCase):
                 "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": True,
                 "sequence_len": 1024,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
                 "lora_r": 4,
                 "lora_alpha": 8,
@@ -91,7 +94,10 @@
                 "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                 "flash_attention": False,
                 "sequence_len": 1024,
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
                 "lora_r": 4,
                 "lora_alpha": 8,

View File

@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
                 },
                 "datasets": [
                     {
-                        "path": "vicgalle/alpaca-gpt4",
+                        "path": "tatsu-lab/alpaca",
                         "type": "alpaca",
+                        "split": "train[:10%]",
                     },
                 ],
                 "num_epochs": 1,

View File

@@ -0,0 +1,141 @@
+"""
+E2E tests for training with quantized model
+"""
+
+import logging
+import os
+import unittest
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_tensorboard, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestHQQ(unittest.TestCase):
+    """
+    Test cases for training of HQQ-quantized llama models"""
+
+    @with_temp_dir
+    def test_hqq_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "use_hqq": True,
+                "hqq_config": [
+                    {
+                        "nbits": 8,
+                        "group_size": 64,
+                    }
+                ],
+                "adapter": "lora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )
+
+    @with_temp_dir
+    def test_hqq_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "use_hqq": True,
+                "hqq_config": [
+                    {
+                        "nbits": 4,
+                        "group_size": 64,
+                    }
+                ],
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )

View File

@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
                     "deepspeed": "deepspeed_configs/zero3_bf16.json",
                     "gradient_checkpointing": True,
                     "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 4,
+                    },
+                    # "load_in_4bit": True
                     "adapter": "qlora",
                 }
                 | minimal_cfg
@@ -93,7 +97,10 @@
                    "deepspeed": "",
                    "gradient_checkpointing": True,
                    "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 4,
+                    },
                     "adapter": "qlora",
                 }
                 | minimal_cfg
@@ -107,7 +114,10 @@
                    "deepspeed": None,
                    "gradient_checkpointing": True,
                    "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 4,
+                    },
                     "adapter": "qlora",
                 }
                 | minimal_cfg
@@ -306,7 +316,10 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 8,
+                    },
                 }
             )
             | base_cfg
@@ -318,7 +331,9 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "gptq": True,
+                    "quantization": {
+                        "backend": "gptq",
+                    },
                 }
             )
             | base_cfg
@@ -330,19 +345,24 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": False,
+                    "quantization": {
+                        "bits": None,
+                    },
                 }
             )
             | base_cfg
         )

-        with pytest.raises(ValueError, match=r".*4bit.*"):
+        with pytest.raises(ValueError, match=r".*bits <= 4*"):
             validate_config(cfg)

         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -364,7 +384,10 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 8,
+                    },
                 }
             )
             | base_cfg
@@ -376,7 +399,10 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "gptq": True,
+                    "quantization": {
+                        "backend": "gptq",
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -388,7 +414,9 @@
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -976,7 +1004,9 @@
         cfg = (
             DictDefault(
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": None,
+                    },
                 }
             )
             | minimal_cfg
@@ -984,29 +1014,16 @@
         with pytest.raises(
             ValueError,
-            match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
+            match=r"Quantization is not supported without setting an adapter.*",
         ):
             validate_config(cfg)

         cfg = (
             DictDefault(
                 {
-                    "load_in_8bit": True,
-                }
-            )
-            | minimal_cfg
-        )
-
-        with pytest.raises(
-            ValueError,
-            match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
-        ):
-            validate_config(cfg)
-
-        cfg = (
-            DictDefault(
-                {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": 4,
+                    },
                     "adapter": "qlora",
                 }
             )
@@ -1018,7 +1035,9 @@
         cfg = (
             DictDefault(
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "bits": 8,
+                    },
                     "adapter": "lora",
                 }
             )

View File

@@ -21,8 +21,10 @@ class TestModelsUtils:
                 "base_model": "JackFram/llama-68m",
                 "model_type": "LlamaForCausalLM",
                 "tokenizer_type": "LlamaTokenizer",
-                "load_in_8bit": True,
-                "load_in_4bit": False,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 8,
+                },
                 "adapter": "lora",
                 "flash_attention": False,
                 "sample_packing": True,