Compare commits

..

4 Commits

Author SHA1 Message Date
Wing Lian
a0670abc94 add output for train loss in assertian err 2025-04-18 08:11:11 -07:00
Wing Lian
08f287b57f swap llama tests for 7m param model 2025-04-17 09:52:35 -07:00
Wing Lian
b4c7d9c29d fix perplexity scores 2025-04-17 07:58:53 -07:00
Wing Lian
d2637fb01d first pass at modifying tests to use llama-7m 2025-04-16 21:14:04 -07:00
40 changed files with 194 additions and 597 deletions

View File

@@ -1,10 +1,13 @@
#!/bin/bash #!/bin/bash
set -e set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection) # Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \ /workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \ --cov=axolotl \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
@@ -14,11 +17,6 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov-append \ --cov-append \
--cov-report=xml:multigpu-coverage.xml --cov-report=xml:multigpu-coverage.xml
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov # Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -55,46 +55,20 @@ overrides_of_model_config:
overrides_of_model_kwargs: overrides_of_model_kwargs:
# use_cache: False # use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config parameters are needed. See: https://huggingface.co/docs/transformers/main/en/quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: load_in_8bit: true
# Use bitsandbytes 4 bit # Use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.8 liger-kernel==0.5.6
# END section # END section
packaging==23.2 packaging==23.2
@@ -22,7 +22,6 @@ hf_xet==1.0.0
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
hqq==0.2.5
sentencepiece sentencepiece
gradio==5.23.3 gradio==5.23.3

View File

@@ -1040,11 +1040,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes: if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.trl and self.cfg.trl.beta is not None: if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
elif self.cfg.rl_beta is not None: if self.cfg.orpo_alpha:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2 - deepseek_v2
- gemma - gemma
- gemma2 - gemma2
- gemma3 - gemma3 (partial support, no support for FLCE yet)
- granite - granite
- jamba - jamba
- llama - llama

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect import inspect
import logging import logging
import sys import sys
from functools import partial
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
) )
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy: if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4": elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import ( from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4, apply_liger_kernel_to_llama4,

View File

@@ -236,18 +236,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg): def normalize_cfg_datasets(cfg):
""" """

View File

@@ -3,7 +3,6 @@
import functools import functools
import logging import logging
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -118,27 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from iter_ds = load_dataset(
# other ranks, we just need to present a fake dataset path, streaming=True, split=split, name=name, data_files=data_files
if ( )
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip: if skip:
LOG.info(f"Skipping {skip} samples from the dataset") LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip) iter_ds = iter_ds.skip(skip)

View File

@@ -36,7 +36,6 @@ from transformers import (
BitsAndBytesConfig, BitsAndBytesConfig,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
GPTQConfig, GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration, Llama4ForConditionalGeneration,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration, Mistral3ForConditionalGeneration,
@@ -834,13 +833,6 @@ class ModelLoader:
del self.model_kwargs["device_map"] del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None: def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -862,21 +854,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config") and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"] and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"] in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
): ):
quant_config_class_dict = { if self.model_config.quantization_config["quant_method"] == "gptq":
"gptq": GPTQConfig, self.model_kwargs["quantization_config"] = GPTQConfig(
"awq": AwqConfig, **self.model_config.quantization_config
"bitsandbytes": BitsAndBytesConfig, )
} elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
quant_config_class = quant_config_class_dict[ **self.model_config.quantization_config
self.model_config.quantization_config["quant_method"] )
] elif (
self.model_kwargs["quantization_config"] = quant_config_class( self.model_config.quantization_config["quant_method"] == "bitsandbytes"
**self.model_config.quantization_config ):
) self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = { bnb_config = {
"load_in_4bit": True, "load_in_4bit": True,
@@ -894,8 +886,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16 # but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32 bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs: if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs) bnb_config.update(self.cfg.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig( self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config, **bnb_config,
@@ -911,13 +903,6 @@ class ModelLoader:
**bnb_config, **bnb_config,
) )
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610 # no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq: if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None) self.model_kwargs.pop("load_in_8bit", None)
@@ -1051,12 +1036,6 @@ class ModelLoader:
config=self.model_config, config=self.model_config,
) )
else: else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained( self.model = self.auto_model_loader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,
@@ -1211,7 +1190,7 @@ class ModelLoader:
if ( if (
not skip_prepare_model_for_kbit_training not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"] and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq) and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training( self.model = prepare_model_for_kbit_training(
@@ -1481,16 +1460,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
from hqq.core.peft import HQQLinearLoRA cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if (

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr self.max_lr = max_lr
self.total_steps = total_steps self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps self.num_warmup_steps = num_warmup_steps
self.last_step = max(last_step - 1, 0) self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups: for group in optimizer.param_groups:

View File

@@ -660,7 +660,6 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0 data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy")) and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets") and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
): ):
raise ValueError( raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0" "eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -1,8 +1,8 @@
"""Pydantic models for PEFT-related configuration""" """Pydantic models for PEFT-related configuration"""
from pydantic import BaseModel, Field, field_validator, model_validator from typing import Any
from axolotl.utils.schemas.quant import QuantizationConfig from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
@@ -23,11 +23,8 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel): class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset""" """Peft / LoRA configuration subset"""
quantization: QuantizationConfig | None = None load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = None # for internal use load_in_4bit: bool | None = Field(default=False)
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
adapter: str | None = None adapter: str | None = None
lora_model_dir: str | None = None lora_model_dir: str | None = None
@@ -53,6 +50,8 @@ class LoraConfig(BaseModel):
}, },
) )
lora_on_cpu: bool | None = None lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field( loraplus_lr_ratio: float | None = Field(
default=None, default=None,
@@ -75,11 +74,11 @@ class LoraConfig(BaseModel):
if ( if (
not data.get("adapter") not data.get("adapter")
and not data.get("inference") and not data.get("inference")
and (data.get("quantization")) and (data.get("load_in_8bit") or data.get("load_in_4bit"))
): ):
raise ValueError( raise ValueError(
"Quantization is not supported without setting an adapter." "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off Quantization." "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
) )
return data return data
@@ -87,26 +86,25 @@ class LoraConfig(BaseModel):
def validate_qlora(self): def validate_qlora(self):
if self.adapter == "qlora": if self.adapter == "qlora":
if self.merge_lora: if self.merge_lora:
if self.quantization.bits == 8 or self.load_in_8bit: # can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit") raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq": if self.gptq:
raise ValueError("Can't merge qlora if using gptq") raise ValueError("Can't merge qlora if gptq")
if self.quantization.bits == 4 or self.load_in_4bit: if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit") raise ValueError("Can't merge qlora if loaded in 4bit")
else: else:
if self.quantization: if self.load_in_8bit:
if self.quantization.bits == 8 or self.load_in_8bit: raise ValueError("Can't load qlora in 8bit")
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq": if self.gptq:
raise ValueError("Can't load qlora if using gptq") raise ValueError("Can't load qlora if gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self return self
@field_validator("loraplus_lr_embedding") @field_validator("loraplus_lr_embedding")
@@ -123,24 +121,6 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0 data["lora_dropout"] = 0.0
return data return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel): class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset""" """ReLoRA configuration subset"""

View File

@@ -1,93 +0,0 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Overall quantization configuration subset"""
# This class will serve as the base for a future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs

View File

@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset") snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture(): def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model") snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture(): def download_llama_1b_model_fixture():
# download the tokenizer only # download the tokenizer only
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture @pytest.fixture
@enable_hf_offline @enable_hf_offline
def tokenizer_huggyllama( def tokenizer_huggyllama(
@@ -522,6 +496,12 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"] return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(scope="session", autouse=True)
def download_tiny_llama_7m_model():
# download the model
return snapshot_download_w_retry("axolotl-ai-internal/llama-7m", repo_type="model")
# # pylint: disable=redefined-outer-name,unused-argument # # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures( # def test_load_fixtures(
# download_smollm2_135m_model, # download_smollm2_135m_model,

View File

@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists() assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
) )

View File

@@ -1,2 +0,0 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,9 +49,8 @@ class TestPackedFlex:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -90,5 +89,5 @@ class TestPackedFlex:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )

View File

@@ -30,10 +30,8 @@ class TestMultiGPUEval:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": { "load_in_8bit": False,
"backend": "bnb", "load_in_4bit": True,
"bits": 4,
},
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",
@@ -101,10 +99,8 @@ class TestMultiGPUEval:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": { "load_in_8bit": False,
"backend": "bnb", "load_in_4bit": True,
"bits": 4,
},
"strict": False, "strict": False,
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "qlora", "adapter": "qlora",

View File

@@ -96,5 +96,5 @@ class TestMultiGPUGemma3:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 1.8, "Train loss (%s) is too high"
) )

View File

@@ -43,7 +43,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
@@ -94,7 +94,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -105,7 +105,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": True, "sample_packing": True,
"eval_sample_packing": False, "eval_sample_packing": False,
@@ -159,22 +159,19 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
def test_dpo_lora_ddp(self, temp_dir): def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
@@ -247,15 +244,12 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"sample_packing": False, "sample_packing": False,
"eval_sample_packing": False, "eval_sample_packing": False,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
@@ -332,7 +326,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sequence_len": 2048, "sequence_len": 2048,
"val_set_size": 0.01, "val_set_size": 0.01,
"special_tokens": { "special_tokens": {
@@ -391,7 +385,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -402,7 +396,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -463,7 +457,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@require_torch_2_6_0 @require_torch_2_6_0
@@ -481,7 +475,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 2048, "sequence_len": 2048,
@@ -544,7 +538,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.1, "Train loss (%s) is too high"
) )
def test_fsdp_qlora_prequant_packed(self, temp_dir): def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -554,10 +548,7 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16", "base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora", "adapter": "qlora",
"mean_resizing_embeddings": True, "mean_resizing_embeddings": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
@@ -627,7 +618,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -657,16 +648,13 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -714,7 +702,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -734,16 +722,13 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -791,7 +776,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -811,16 +796,13 @@ class TestMultiGPULlama:
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
else: else:
adapter = {} adapter = {}
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -868,7 +850,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@pytest.mark.skip( @pytest.mark.skip(
@@ -878,7 +860,7 @@ class TestMultiGPULlama:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "axolotl-ai-internal/llama-7m",
"fix_untrained_tokens": True, "fix_untrained_tokens": True,
"sequence_len": 512, "sequence_len": 512,
"val_set_size": 0.0, "val_set_size": 0.0,
@@ -935,5 +917,5 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 4.0, "Train loss (%s) is too high"
) )

View File

@@ -28,10 +28,7 @@ class TestMultiGPUQwen2:
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": base_model, "base_model": base_model,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"rl": "dpo", "rl": "dpo",
"chat_template": "chatml", "chat_template": "chatml",
"sequence_len": 2048, "sequence_len": 2048,

View File

@@ -80,7 +80,7 @@ class TestMultiGPURay:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )
@require_torch_lt_2_6_0 @require_torch_lt_2_6_0
@@ -138,5 +138,5 @@ class TestMultiGPURay:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
) )

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard from ..utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -93,7 +93,7 @@ class TestSequenceParallelism:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.6, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -86,5 +86,5 @@ class TestFAXentropyLlama:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 1.5, "Train loss (%s) is too high"
) )

View File

@@ -32,10 +32,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
"lora_alpha": 32, "lora_alpha": 32,

View File

@@ -89,9 +89,6 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"gptq": True, "gptq": True,

View File

@@ -33,10 +33,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 2048, "sequence_len": 2048,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 16, "lora_r": 16,
"lora_alpha": 32, "lora_alpha": 32,

View File

@@ -46,9 +46,8 @@ class TestResumeLlama:
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 2, "num_epochs": 2,

View File

@@ -80,7 +80,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
def test_unsloth_llama_qlora_unpacked(self, temp_dir): def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -130,7 +130,7 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -185,5 +185,5 @@ class TestUnslothQLoRA:
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )

View File

@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -70,5 +69,5 @@ class TestPackedFlex(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )

View File

@@ -34,10 +34,7 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
"flash_attention": True, "flash_attention": True,
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"lora_r": 8, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,

View File

@@ -84,5 +84,5 @@ class TestPretrainLlama:
temp_dir + "/runs", temp_dir + "/runs",
"train/train_loss", "train/train_loss",
loss_threshold, loss_threshold,
"Train Loss is too high", "Train Loss (%s) is too high",
) )

View File

@@ -35,10 +35,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 4, "lora_r": 4,
"lora_alpha": 8, "lora_alpha": 8,
@@ -94,10 +91,7 @@ class TestMixtral(unittest.TestCase):
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF", "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False, "flash_attention": False,
"sequence_len": 1024, "sequence_len": 1024,
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
"lora_r": 4, "lora_r": 4,
"lora_alpha": 8, "lora_alpha": 8,

View File

@@ -40,9 +40,8 @@ class TestPackedLlama(unittest.TestCase):
}, },
"datasets": [ "datasets": [
{ {
"path": "tatsu-lab/alpaca", "path": "vicgalle/alpaca-gpt4",
"type": "alpaca", "type": "alpaca",
"split": "train[:10%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
@@ -69,5 +68,5 @@ class TestPackedLlama(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
) )

View File

@@ -1,141 +0,0 @@
"""
E2E tests for training with quantized model
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
# Logger registered under the e2e test suite's name.
LOG = logging.getLogger("axolotl.tests.e2e")
# Set at import time, so it is in the environment before any test in this
# module starts training. Presumably this stops the run from reporting to
# Weights & Biases -- confirm against the trainer's integration handling.
os.environ["WANDB_DISABLED"] = "true"
class TestHQQ(unittest.TestCase):
    """
    Test cases for training of HQQ-quantized llama models
    """

    def _run_hqq_training(self, temp_dir, *, adapter, nbits):
        """
        Train for 5 steps with HQQ enabled and check the logged train loss.

        The two test cases in this class use the same config apart from the
        adapter type and the HQQ bit width, so the shared flow lives here.
        The name has no ``test`` prefix, so unittest does not collect it as
        a test case of its own.

        Args:
            temp_dir: directory passed as ``output_dir``; the tensorboard
                check reads from ``temp_dir + "/runs"``.
            adapter: value of the ``adapter`` config key.
            nbits: value of ``nbits`` in the single ``hqq_config`` entry.
        """
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "sample_packing": True,
                "flash_attention": True,
                "use_hqq": True,
                "hqq_config": [
                    {
                        "nbits": nbits,
                        "group_size": 64,
                    }
                ],
                "adapter": adapter,
                "lora_r": 16,
                "lora_alpha": 32,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.0,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "vicgalle/alpaca-gpt4",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "use_tensorboard": True,
            }
        )
        # Use bf16 when the GPU reports support for it, otherwise fp16.
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        # validate_config returns the config to use from here on;
        # normalize_config is called for its effect on that object.
        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, dataset_meta=dataset_meta)
        # NOTE(review): presumably fails with this message when the logged
        # train loss exceeds 2.0 -- confirm against check_tensorboard.
        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
        )

    @with_temp_dir
    def test_hqq_lora(self, temp_dir):
        """LoRA adapter with an 8-bit HQQ config."""
        self._run_hqq_training(temp_dir, adapter="lora", nbits=8)

    @with_temp_dir
    def test_hqq_qlora(self, temp_dir):
        """QLoRA adapter with a 4-bit HQQ config."""
        self._run_hqq_training(temp_dir, adapter="qlora", nbits=4)

View File

@@ -73,6 +73,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" temp_dir + "/runs", "train/train_loss", 2.5, "Train loss (%s) is too high"
) )
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -74,11 +74,7 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json", "deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -97,10 +93,7 @@ class TestValidation(BaseValidation):
"deepspeed": "", "deepspeed": "",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -114,10 +107,7 @@ class TestValidation(BaseValidation):
"deepspeed": None, "deepspeed": None,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -316,10 +306,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -331,9 +318,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "gptq": True,
"backend": "gptq",
},
} }
) )
| base_cfg | base_cfg
@@ -345,24 +330,19 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": False,
"bits": None,
},
} }
) )
| base_cfg | base_cfg
) )
with pytest.raises(ValueError, match=r".*bits <= 4*"): with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": True,
"backend": "bnb",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -384,10 +364,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_8bit": True,
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -399,10 +376,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "gptq": True,
"backend": "gptq",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -414,9 +388,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"quantization": { "load_in_4bit": True,
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -1004,9 +976,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_4bit": True,
"bits": None,
},
} }
) )
| minimal_cfg | minimal_cfg
@@ -1014,16 +984,29 @@ class TestValidation(BaseValidation):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r"Quantization is not supported without setting an adapter.*", match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
): ):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_8bit": True,
"bits": 4, }
}, )
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
} }
) )
@@ -1035,9 +1018,7 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"quantization": { "load_in_8bit": True,
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
} }
) )

View File

@@ -8,7 +8,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.callbacks.perplexity import Perplexity
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M" MODEL_NAME = "axolotl-ai-internal/llama-7m"
@fixture() @fixture()
@@ -36,7 +36,7 @@ One day, a little fish named Fin was swimming near the shore. He saw a big crab
""" """
result = metric.compute(model, [sample_text]) result = metric.compute(model, [sample_text])
ppl = result["score"] ppl = result["score"]
assert round(ppl, 2) == 7.41 assert round(ppl, 2) == 75.14
def test_perplexity_short(model, metric): def test_perplexity_short(model, metric):
@@ -44,4 +44,4 @@ def test_perplexity_short(model, metric):
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun." sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
result = metric.compute(model, [sample_text]) result = metric.compute(model, [sample_text])
ppl = result["score"] ppl = result["score"]
assert round(ppl, 2) == 10.33 assert round(ppl, 2) == 70.54

View File

@@ -21,10 +21,8 @@ class TestModelsUtils:
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM", "model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"quantization": { "load_in_8bit": True,
"backend": "bnb", "load_in_4bit": False,
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
"flash_attention": False, "flash_attention": False,
"sample_packing": True, "sample_packing": True,