Compare commits

..

13 Commits

Author SHA1 Message Date
Sung Ching Liu
f8e92407ff Update src/axolotl/common/datasets.py
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-04-17 09:47:14 -04:00
Sung Ching Liu
c12906134d Update src/axolotl/prompt_strategies/base.py
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-04-17 09:47:14 -04:00
Sunny Liu
8154d26614 nit 2025-04-17 09:47:14 -04:00
Sunny Liu
fefcbc300d barebone-ify the test so we get rid of unneeded processes 2025-04-17 09:47:14 -04:00
Sunny Liu
7d479348ee custom reward function loading, proeprly done 2025-04-17 09:47:14 -04:00
bursteratom
ce0259db13 add outputdir 2025-04-17 09:47:14 -04:00
Sung Ching Liu
2798817cf9 Update tests/e2e/solo/test_grpo.py
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-04-17 09:47:14 -04:00
Sunny Liu
0e1b081e49 add unit test 2025-04-17 09:47:14 -04:00
Sunny Liu
8df37ad91f propoer import from file_path after all else fails 2025-04-17 09:47:14 -04:00
Sung Ching Liu
9b74298328 Update src/axolotl/prompt_strategies/base.py
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-04-17 09:47:14 -04:00
Sunny Liu
ae8738aa87 skip check_datasets_label during debug for grpo 2025-04-17 09:47:14 -04:00
Sunny Liu
ec52561a0c import from filepath if can't import_module 2025-04-17 09:47:14 -04:00
Sunny Liu
eadb16c709 test import-wihtin-import relative path 2025-04-17 09:47:14 -04:00
35 changed files with 308 additions and 587 deletions

View File

@@ -1,10 +1,13 @@
#!/bin/bash #!/bin/bash
set -e set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \ --cov=axolotl \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
@@ -14,11 +17,6 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov-append \ --cov-append \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -55,46 +55,20 @@ overrides_of_model_config:
overrides_of_model_kwargs: overrides_of_model_kwargs:
# use_cache: False # use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config paramters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: load_in_8bit: true
# Use bitsandbytes 4 bit # Use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.8 liger-kernel==0.5.6
# END section # END section
packaging==23.2 packaging==23.2
@@ -22,7 +22,6 @@ hf_xet==1.0.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
hqq==0.2.5
sentencepiece sentencepiece
gradio==5.23.3 gradio==5.23.3

View File

@@ -129,17 +129,19 @@ def load_preference_datasets(
total_num_steps = None total_num_steps = None
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") if not cfg.rl == "grpo":
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples, check_dataset_labels(
tokenizer, train_samples,
num_examples=cli_args.debug_num_examples, tokenizer,
text_only=cli_args.debug_text_only, num_examples=cli_args.debug_num_examples,
rl_mode=True, text_only=cli_args.debug_text_only,
) rl_mode=True,
)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,

View File

@@ -1040,11 +1040,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes: if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None: if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
elif self.cfg.rl_beta is not None: if self.cfg.orpo_alpha:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2 - deepseek_v2
- gemma - gemma
- gemma2 - gemma2
- gemma3 - gemma3 (partial support, no support for FLCE yet)
- granite - granite
- jamba - jamba
- llama - llama

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
) )
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4": elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import ( from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4, apply_liger_kernel_to_llama4,

View File

@@ -4,30 +4,73 @@ module for base dataset transform strategies
import importlib import importlib
import logging import logging
import sys
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def import_from_path(module_name: str, file_path: str):
"""
Import a module from a file path.
Args:
module_name: Name of the module.
file_path: Path to the file.
Returns:
module: The imported module.
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Could not create module spec for: {file_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
loader = importlib.machinery.SourceFileLoader(module_name, file_path)
spec.loader = loader
loader.exec_module(module)
return module
def load(strategy, cfg, module_base=None, **kwargs): def load(strategy, cfg, module_base=None, **kwargs):
try: if len(strategy.split(".")) == 1:
if len(strategy.split(".")) == 1: strategy = strategy + ".default"
strategy = strategy + ".default" load_fn = strategy.split(".")[-1]
load_fn = strategy.split(".")[-1] func = None
if len(strategy.split(".")) > 1: if len(strategy.split(".")) > 1:
try: try:
importlib.import_module( mod = importlib.import_module(
strategy.split(".")[-2], strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]), ".".join(strategy.split(".")[:-2]),
) )
module_base = ".".join(strategy.split(".")[:-2]) func = getattr(mod, load_fn)
strategy = strategy.split(".")[-2] return func(cfg, **kwargs)
except ModuleNotFoundError: except ModuleNotFoundError:
strategy = "." + ".".join(strategy.split(".")[:-1]) pass
else:
strategy = "." + ".".join(strategy.split(".")[:-1]) try:
mod = importlib.import_module(
"." + ".".join(strategy.split(".")[:-1]), module_base
)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except ModuleNotFoundError:
pass
try:
file_path = "/".join(strategy.split(".")[:-1]) + ".py"
module_name = strategy.split(".")[-2]
mod = import_from_path(module_name, file_path)
func = getattr(mod, load_fn)
if func is not None:
return func(cfg, **kwargs)
except FileNotFoundError:
pass
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base) mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn) func = getattr(mod, load_fn)
return func(cfg, **kwargs) return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}") LOG.warning(f"unable to load strategy {strategy}")
return None return func

View File

@@ -236,18 +236,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg): def normalize_cfg_datasets(cfg):
""" """

View File

@@ -3,7 +3,6 @@
import functools import functools
import logging import logging
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -118,27 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from iter_ds = load_dataset(
# other ranks, we just need to present a fake dataset path, streaming=True, split=split, name=name, data_files=data_files
if ( )
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip: if skip:
LOG.info(f"Skipping {skip} samples from the dataset") LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip) iter_ds = iter_ds.skip(skip)

View File

@@ -36,7 +36,6 @@ from transformers import (
BitsAndBytesConfig, BitsAndBytesConfig,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
GPTQConfig, GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
@@ -834,13 +833,6 @@ class ModelLoader:
del self.model_kwargs["device_map"] del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None: def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -862,21 +854,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config") and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"] and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"] in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
): ):
quant_config_class_dict = { if self.model_config.quantization_config["quant_method"] == "gptq":
"gptq": GPTQConfig, self.model_kwargs["quantization_config"] = GPTQConfig(
"awq": AwqConfig, **self.model_config.quantization_config
"bitsandbytes": BitsAndBytesConfig, )
} elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
quant_config_class = quant_config_class_dict[ **self.model_config.quantization_config
self.model_config.quantization_config["quant_method"] )
] elif (
self.model_kwargs["quantization_config"] = quant_config_class( self.model_config.quantization_config["quant_method"] == "bitsandbytes"
**self.model_config.quantization_config ):
) self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
@@ -894,8 +886,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32 bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs: if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs) bnb_config.update(self.cfg.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
@@ -911,13 +903,6 @@ class ModelLoader:
**bnb_config, **bnb_config,
) )
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610 # no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq: if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None) self.model_kwargs.pop("load_in_8bit", None)
@@ -1051,12 +1036,6 @@ class ModelLoader:
config=self.model_config, config=self.model_config,
) )
else: else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1211,7 +1190,7 @@ class ModelLoader:
if ( if (
not skip_prepare_model_for_kbit_training not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"] and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq) and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training( self.model = prepare_model_for_kbit_training(
@@ -1481,16 +1460,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
from hqq.core.peft import HQQLinearLoRA cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if (

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr self.max_lr = max_lr
self.total_steps = total_steps self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps self.num_warmup_steps = num_warmup_steps
self.last_step = max(last_step - 1, 0) self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups: for group in optimizer.param_groups:

View File

@@ -660,7 +660,6 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0 data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy")) and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets") and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
): ):
raise ValueError( raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0" "eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -1,8 +1,8 @@
"""Pydantic models for PEFT-related configuration""" """Pydantic models for PEFT-related configuration"""
from pydantic import BaseModel, Field, field_validator, model_validator from typing import Any
from axolotl.utils.schemas.quant import QuantizationConfig from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
@@ -23,11 +23,8 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset""" """Peft / LoRA configuration subset"""
quantization: QuantizationConfig | None = None load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = None # for internal use load_in_4bit: bool | None = Field(default=False)
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
adapter: str | None = None adapter: str | None = None
lora_model_dir: str | None = None lora_model_dir: str | None = None
@@ -53,6 +50,8 @@ class LoraConfig(BaseModel):
}, },
) )
lora_on_cpu: bool | None = None lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field( loraplus_lr_ratio: float | None = Field(
default=None, default=None,
@@ -75,11 +74,11 @@ class LoraConfig(BaseModel):
if ( if (
not data.get("adapter") not data.get("adapter")
and not data.get("inference") and not data.get("inference")
and (data.get("quantization")) and (data.get("load_in_8bit") or data.get("load_in_4bit"))
): ):
raise ValueError( raise ValueError(
"Quantization is not supported without setting an adapter." "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off Quantization." "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
) )
return data return data
@@ -87,26 +86,25 @@ class LoraConfig(BaseModel):
def validate_qlora(self): def validate_qlora(self):
if self.adapter == "qlora": if self.adapter == "qlora":
if self.merge_lora: if self.merge_lora:
if self.quantization.bits == 8 or self.load_in_8bit: # can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit") raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq": if self.gptq:
raise ValueError("Can't merge qlora if using gptq") raise ValueError("Can't merge qlora if gptq")
if self.quantization.bits == 4 or self.load_in_4bit: if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit") raise ValueError("Can't merge qlora if loaded in 4bit")
else: else:
if self.quantization: if self.load_in_8bit:
if self.quantization.bits == 8 or self.load_in_8bit: raise ValueError("Can't load qlora in 8bit")
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq": if self.gptq:
raise ValueError("Can't load qlora if using gptq") raise ValueError("Can't load qlora if gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self return self
@field_validator("loraplus_lr_embedding") @field_validator("loraplus_lr_embedding")
@@ -123,24 +121,6 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0 data["lora_dropout"] = 0.0
return data return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel): class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset""" """ReLoRA configuration subset"""

View File

@@ -1,93 +0,0 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs

View File

@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture(): def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model") snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture(): def download_llama_1b_model_fixture():
# download the tokenizer only # download the tokenizer only
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture @pytest.fixture
@enable_hf_offline @enable_hf_offline
def tokenizer_huggyllama( def tokenizer_huggyllama(

View File

@@ -1,2 +0,0 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,9 +49,8 @@ class TestPackedFlex:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -30,10 +30,8 @@ class TestMultiGPUEval:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": { "load_in_8bit": False,
"backend": "bnb", "load_in_4bit": True,
"bits": 4,
},
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",
@@ -101,10 +99,8 @@ class TestMultiGPUEval:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": { "load_in_8bit": False,
"backend": "bnb", "load_in_4bit": True,
"bits": 4,
},
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",

View File

@@ -171,10 +171,7 @@ class TestMultiGPULlama:
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
@@ -252,10 +249,7 @@ class TestMultiGPULlama:
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
@@ -554,10 +548,7 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16", "base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora", "adapter": "qlora",
"mean_resizing_embeddings": True, "mean_resizing_embeddings": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
@@ -657,10 +648,7 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}
@@ -734,10 +722,7 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}
@@ -811,10 +796,7 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}

View File

@@ -28,10 +28,7 @@ class TestMultiGPUQwen2:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": base_model, "base_model": base_model,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"rl": "dpo", "rl": "dpo",
"chat_template": "chatml", "chat_template": "chatml",
"sequence_len": 2048, "sequence_len": 2048,

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard from ..utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -32,10 +32,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
"lora_alpha": 32, "lora_alpha": 32,

View File

@@ -89,9 +89,6 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"gptq": True, "gptq": True,

View File

@@ -33,10 +33,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
"lora_alpha": 32, "lora_alpha": 32,

View File

@@ -46,9 +46,8 @@ class TestResumeLlama:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 2, "num_epochs": 2,

View File

@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -0,0 +1,85 @@
"""
E2E tests for preprocessing
"""
import logging
import os
import unittest
import transformers
from axolotl.cli.args import PreprocessCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomRewardFunctionLoading(unittest.TestCase):
"""
Test case for GRPO training using single GPU
"""
def _utils_write_rewards(self):
# write cfg to yaml file
with open("rewards.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@with_temp_dir
def test_custom_rewards_fn_preprocess(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"strict": False,
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [
"rewards.rand_reward_func"
], # format: '{file_name}.{fn_name}'
"reward_weights": [1.0],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": "rewards.oai_gsm8k_transform",
},
],
"dataset_prepared_path": temp_dir,
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"learning_rate": 0.000005,
}
)
self._utils_write_rewards()
cfg = validate_config(cfg)
normalize_config(cfg)
parser = transformers.HfArgumentParser(PreprocessCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -34,10 +34,7 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"flash_attention": True, "flash_attention": True,
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,

View File

@@ -35,10 +35,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 4, "lora_r": 4,
"lora_alpha": 8, "lora_alpha": 8,
@@ -94,10 +91,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False, "flash_attention": False,
"sequence_len": 1024, "sequence_len": 1024,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 4, "lora_r": 4,
"lora_alpha": 8, "lora_alpha": 8,

View File

@@ -40,9 +40,8 @@ class TestPackedLlama(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,

View File

@@ -1,141 +0,0 @@
"""
E2E tests for training with quantized model
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestHQQ(unittest.TestCase):
"""
Test cases for training of HQQ-quantized llama models"""
@with_temp_dir
def test_hqq_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 8,
"group_size": 64,
}
],
"adapter": "lora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)
@with_temp_dir
def test_hqq_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"use_hqq": True,
"hqq_config": [
{
"nbits": 4,
"group_size": 64,
}
],
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
)

View File

@@ -74,11 +74,7 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json", "deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -97,10 +93,7 @@ class TestValidation(BaseValidation):
"deepspeed": "", "deepspeed": "",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -114,10 +107,7 @@ class TestValidation(BaseValidation):
"deepspeed": None, "deepspeed": None,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -316,10 +306,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -331,9 +318,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "gptq": True,
"backend": "gptq",
},
} }
) )
| base_cfg | base_cfg
@@ -345,24 +330,19 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": False,
"bits": None,
},
} }
) )
| base_cfg | base_cfg
) )
with pytest.raises(ValueError, match=r".*bits <= 4*"): with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -384,10 +364,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -399,10 +376,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "gptq": True,
"backend": "gptq",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -414,9 +388,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": True,
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -1004,9 +976,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_4bit": True,
"bits": None,
},
} }
) )
| minimal_cfg | minimal_cfg
@@ -1014,16 +984,29 @@ class TestValidation(BaseValidation):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r"Quantization is not supported without setting an adapter.*", match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
): ):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_8bit": True,
"bits": 4, }
}, )
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
} }
) )
@@ -1035,9 +1018,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_8bit": True,
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
} }
) )

View File

@@ -21,10 +21,8 @@ class TestModelsUtils:
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM", "model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"quantization": { "load_in_8bit": True,
"backend": "bnb", "load_in_4bit": False,
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"flash_attention": False, "flash_attention": False,
"sample_packing": True, "sample_packing": True,