Compare commits
37 Commits
attention_
...
feat_hqq
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0179021780 | ||
|
|
c4910da015 | ||
|
|
db7e92f6a6 | ||
|
|
136b37e4d4 | ||
|
|
92644513c4 | ||
|
|
266ef3f479 | ||
|
|
fcef8c95fe | ||
|
|
136407c556 | ||
|
|
3251b3235f | ||
|
|
1aa9f7d952 | ||
|
|
a20e753321 | ||
|
|
cb121ab91b | ||
|
|
b59640a4c7 | ||
|
|
f0a189131b | ||
|
|
c8fb5baad6 | ||
|
|
9be971d47c | ||
|
|
ffd4ef1ece | ||
|
|
320aff1867 | ||
|
|
ac24eba2ac | ||
|
|
8a5ad8aee3 | ||
|
|
843b50fdaa | ||
|
|
098ffcc5a2 | ||
|
|
ba8e29c841 | ||
|
|
143b2e082c | ||
|
|
aba484de97 | ||
|
|
f6f5f89c6d | ||
|
|
8926fe9981 | ||
|
|
987c5217a0 | ||
|
|
feaef03cb9 | ||
|
|
ba5d917845 | ||
|
|
0e9b060b4d | ||
|
|
0c40d12a18 | ||
|
|
f55b3c805b | ||
|
|
a64601f957 | ||
|
|
eb7bc70b99 | ||
|
|
db6c76b147 | ||
|
|
99730ce40a |
@@ -55,20 +55,46 @@ overrides_of_model_config:
|
|||||||
overrides_of_model_kwargs:
|
overrides_of_model_kwargs:
|
||||||
# use_cache: False
|
# use_cache: False
|
||||||
|
|
||||||
# optional overrides to the bnb 4bit quantization configuration
|
|
||||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
|
||||||
bnb_config_kwargs:
|
|
||||||
# These are default values
|
|
||||||
llm_int8_has_fp16_weight: false
|
|
||||||
bnb_4bit_quant_type: nf4
|
|
||||||
bnb_4bit_use_double_quant: true
|
|
||||||
|
|
||||||
|
|
||||||
|
# Quantization configuration.
|
||||||
|
quantization:
|
||||||
|
backend: bnb | hqq | gptq
|
||||||
|
bits: 8
|
||||||
|
# optional overrides to the bnb 4bit quantization configuration
|
||||||
|
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||||
|
bnb_config_kwargs:
|
||||||
|
# These are default values
|
||||||
|
llm_int8_has_fp16_weight: false
|
||||||
|
bnb_4bit_quant_type: nf4
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
# If using hqq config, additional config paramters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
|
||||||
|
hqq_config:
|
||||||
|
# pick one of the following, depending on if you want to uniformly quantize the whole model or
|
||||||
|
# apply different quantization settings to specific layers in the model:
|
||||||
|
|
||||||
|
# if uniformly quantize the whole model:
|
||||||
|
group_size: 64
|
||||||
|
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
|
||||||
|
- nbits: 4
|
||||||
|
group_size: 64
|
||||||
|
target_modules:
|
||||||
|
- self_attn.k_proj
|
||||||
|
- self_attn.v_proj
|
||||||
|
- self_attn.o_proj
|
||||||
|
- nbits: 3
|
||||||
|
group_size: 32
|
||||||
|
target_modules:
|
||||||
|
- mlp.gate_proj
|
||||||
|
- mlp.up_proj
|
||||||
|
- mlp.down_proj
|
||||||
|
|
||||||
|
# (Internal Use Only)
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq:
|
||||||
|
|
||||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||||
load_in_8bit: true
|
load_in_8bit:
|
||||||
# Use bitsandbytes 4 bit
|
# Use bitsandbytes 4 bit
|
||||||
load_in_4bit:
|
load_in_4bit:
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ hf_xet==1.0.0
|
|||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
hqq==0.2.5
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.23.3
|
||||||
|
|
||||||
|
|||||||
@@ -236,6 +236,18 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
if cfg.quantization:
|
||||||
|
if cfg.quantization.backend in ["bnb"]:
|
||||||
|
if cfg.quantization.bits == 8:
|
||||||
|
cfg.load_in_8bit = True
|
||||||
|
elif cfg.quantization.bits == 4:
|
||||||
|
cfg.load_in_4bit = True
|
||||||
|
|
||||||
|
if cfg.quantization.backend == "gptq":
|
||||||
|
cfg.gptq = True
|
||||||
|
elif cfg.quantization.backend == "hqq":
|
||||||
|
cfg.hqq = True
|
||||||
|
|
||||||
|
|
||||||
def normalize_cfg_datasets(cfg):
|
def normalize_cfg_datasets(cfg):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
HqqConfig,
|
||||||
Llama4ForConditionalGeneration,
|
Llama4ForConditionalGeneration,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
Mistral3ForConditionalGeneration,
|
||||||
@@ -833,6 +834,13 @@ class ModelLoader:
|
|||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
def set_quantization_config(self) -> None:
|
def set_quantization_config(self) -> None:
|
||||||
|
if (
|
||||||
|
(not self.cfg.quantization)
|
||||||
|
and (not self.cfg.load_in_8bit)
|
||||||
|
and (not self.cfg.load_in_4bit)
|
||||||
|
and not self.cfg.gptq
|
||||||
|
):
|
||||||
|
return
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||||
|
|
||||||
@@ -854,21 +862,21 @@ class ModelLoader:
|
|||||||
and hasattr(self.model_config, "quantization_config")
|
and hasattr(self.model_config, "quantization_config")
|
||||||
and self.model_config.quantization_config["quant_method"]
|
and self.model_config.quantization_config["quant_method"]
|
||||||
in ["gptq", "awq", "bitsandbytes"]
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
|
and not self.cfg.hqq
|
||||||
):
|
):
|
||||||
if self.model_config.quantization_config["quant_method"] == "gptq":
|
quant_config_class_dict = {
|
||||||
self.model_kwargs["quantization_config"] = GPTQConfig(
|
"gptq": GPTQConfig,
|
||||||
**self.model_config.quantization_config
|
"awq": AwqConfig,
|
||||||
)
|
"bitsandbytes": BitsAndBytesConfig,
|
||||||
elif self.model_config.quantization_config["quant_method"] == "awq":
|
}
|
||||||
self.model_kwargs["quantization_config"] = AwqConfig(
|
|
||||||
**self.model_config.quantization_config
|
quant_config_class = quant_config_class_dict[
|
||||||
)
|
self.model_config.quantization_config["quant_method"]
|
||||||
elif (
|
]
|
||||||
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
|
self.model_kwargs["quantization_config"] = quant_config_class(
|
||||||
):
|
**self.model_config.quantization_config
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
)
|
||||||
**self.model_config.quantization_config
|
|
||||||
)
|
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
@@ -886,8 +894,8 @@ class ModelLoader:
|
|||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
|
||||||
if self.cfg.bnb_config_kwargs:
|
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
@@ -903,6 +911,13 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.hqq:
|
||||||
|
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
||||||
|
|
||||||
|
self.model_kwargs["quantization_config"] = HqqConfig(
|
||||||
|
**get_hqq_quant_config_kwargs(self.cfg)
|
||||||
|
)
|
||||||
|
|
||||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||||
self.model_kwargs.pop("load_in_8bit", None)
|
self.model_kwargs.pop("load_in_8bit", None)
|
||||||
@@ -1036,6 +1051,12 @@ class ModelLoader:
|
|||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.cfg.hqq and torch.cuda.device_count() < 2:
|
||||||
|
# for some reason on single gpu, we need to set device_map to auto/cuda
|
||||||
|
# otherwise you run into tensors on two devices error during training
|
||||||
|
# Doesn't affect multi-gpu tho
|
||||||
|
|
||||||
|
self.model_kwargs["device_map"] = "auto"
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -1190,7 +1211,7 @@ class ModelLoader:
|
|||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
|
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
self.model = prepare_model_for_kbit_training(
|
self.model = prepare_model_for_kbit_training(
|
||||||
@@ -1460,7 +1481,16 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
from hqq.core.peft import HQQLinearLoRA
|
||||||
|
from hqq.core.quantize import HQQLinear
|
||||||
|
|
||||||
|
cls = (
|
||||||
|
bnb.nn.Linear4bit,
|
||||||
|
bnb.nn.Linear8bitLt,
|
||||||
|
torch.nn.Linear,
|
||||||
|
HQQLinear,
|
||||||
|
HQQLinearLoRA,
|
||||||
|
)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.quant import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
@@ -23,8 +23,11 @@ class PeftConfig(BaseModel):
|
|||||||
class LoraConfig(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
"""Peft / LoRA configuration subset"""
|
"""Peft / LoRA configuration subset"""
|
||||||
|
|
||||||
load_in_8bit: bool | None = Field(default=False)
|
quantization: QuantizationConfig | None = None
|
||||||
load_in_4bit: bool | None = Field(default=False)
|
load_in_4bit: bool | None = None # for internal use
|
||||||
|
load_in_8bit: bool | None = None # for internal use
|
||||||
|
hqq: bool | None = None # for internal use
|
||||||
|
gptq: bool | None = None # for internal use
|
||||||
|
|
||||||
adapter: str | None = None
|
adapter: str | None = None
|
||||||
lora_model_dir: str | None = None
|
lora_model_dir: str | None = None
|
||||||
@@ -50,8 +53,6 @@ class LoraConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
lora_on_cpu: bool | None = None
|
lora_on_cpu: bool | None = None
|
||||||
gptq: bool | None = None
|
|
||||||
bnb_config_kwargs: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
loraplus_lr_ratio: float | None = Field(
|
loraplus_lr_ratio: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -74,11 +75,11 @@ class LoraConfig(BaseModel):
|
|||||||
if (
|
if (
|
||||||
not data.get("adapter")
|
not data.get("adapter")
|
||||||
and not data.get("inference")
|
and not data.get("inference")
|
||||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
and (data.get("quantization"))
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
"Quantization is not supported without setting an adapter."
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off Quantization."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -86,25 +87,26 @@ class LoraConfig(BaseModel):
|
|||||||
def validate_qlora(self):
|
def validate_qlora(self):
|
||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
if self.quantization.bits == 8 or self.load_in_8bit:
|
||||||
if self.load_in_8bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
if self.gptq:
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't merge qlora if gptq")
|
raise ValueError("Can't merge qlora if using gptq")
|
||||||
|
|
||||||
if self.load_in_4bit:
|
if self.quantization.bits == 4 or self.load_in_4bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.load_in_8bit:
|
if self.quantization:
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
if self.quantization.bits == 8 or self.load_in_8bit:
|
||||||
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
if self.gptq:
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't load qlora if gptq")
|
raise ValueError("Can't load qlora if using gptq")
|
||||||
|
|
||||||
|
if not self.quantization.bits == 4 or self.load_in_4bit:
|
||||||
|
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||||
|
|
||||||
if not self.load_in_4bit:
|
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
@field_validator("loraplus_lr_embedding")
|
||||||
@@ -121,6 +123,24 @@ class LoraConfig(BaseModel):
|
|||||||
data["lora_dropout"] = 0.0
|
data["lora_dropout"] = 0.0
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_hqq(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("quantization")
|
||||||
|
and data.get("quantization").get("backend") == "hqq"
|
||||||
|
):
|
||||||
|
if not data.get("quantization").get("hqq_config"):
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ, must set `hqq_config` under `quantization`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.get("load_in_4bit") or data.get("load_in_8bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|||||||
93
src/axolotl/utils/schemas/quant.py
Normal file
93
src/axolotl/utils/schemas/quant.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
""" "
|
||||||
|
Takes care of quantization configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
from annotated_types import MinLen
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class HQQConfig(BaseModel):
|
||||||
|
"""HQQ configuration subset"""
|
||||||
|
|
||||||
|
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
group_size: int = Field(default=64)
|
||||||
|
target_modules: list[str] | str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationConfig(BaseModel):
|
||||||
|
"""Over all Quantization configuration subset"""
|
||||||
|
|
||||||
|
# We will use this class as base future refactoring of all quantization configs
|
||||||
|
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
||||||
|
bits: Literal[8, 4, 3, 2, 1] | None = None
|
||||||
|
bnb_config_kwargs: dict[str, Any] | None = None
|
||||||
|
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_hqq_config(cls, data):
|
||||||
|
if data.get("backend") == "hqq" and not data.get("hqq_config"):
|
||||||
|
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
|
||||||
|
|
||||||
|
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
|
||||||
|
for hqq_config in data.get("hqq_config"):
|
||||||
|
if hqq_config.get("target_modules") is None:
|
||||||
|
raise ValueError(
|
||||||
|
"For list of hqq configs, `target_modules` must be specified for each"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_hqq_quant_config_kwargs(cfg):
|
||||||
|
|
||||||
|
# If no target module is specified, then target the whole model
|
||||||
|
if not isinstance(cfg.quantization.hqq_config, list):
|
||||||
|
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(cfg.quantization.hqq_config) == 1
|
||||||
|
and cfg.quantization.hqq_config[0].target_modules is None
|
||||||
|
):
|
||||||
|
|
||||||
|
nbits = (
|
||||||
|
cfg.quantization.hqq_config[0].nbits
|
||||||
|
if cfg.quantization.hqq_config[0].nbits is not None
|
||||||
|
else cfg.quantization.bits
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"nbits": nbits,
|
||||||
|
"group_size": cfg.quantization.hqq_config[0].group_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
hqq_quant_config_kwargs = {"dynamic_config": {}}
|
||||||
|
for hqq_config in cfg.quantization.hqq_config:
|
||||||
|
nbits = (
|
||||||
|
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
|
||||||
|
)
|
||||||
|
|
||||||
|
target_modules = hqq_config.target_modules
|
||||||
|
if not isinstance(target_modules, list):
|
||||||
|
target_modules = [target_modules]
|
||||||
|
|
||||||
|
for module in target_modules:
|
||||||
|
hqq_quant_config_kwargs["dynamic_config"][module] = {
|
||||||
|
"nbits": nbits,
|
||||||
|
"group_size": hqq_config.group_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
return hqq_quant_config_kwargs
|
||||||
@@ -30,8 +30,10 @@ class TestMultiGPUEval:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"load_in_8bit": False,
|
"quantization": {
|
||||||
"load_in_4bit": True,
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"strict": False,
|
"strict": False,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
@@ -99,8 +101,10 @@ class TestMultiGPUEval:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"load_in_8bit": False,
|
"quantization": {
|
||||||
"load_in_4bit": True,
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"strict": False,
|
"strict": False,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
|
|||||||
@@ -171,7 +171,10 @@ class TestMultiGPULlama:
|
|||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -249,7 +252,10 @@ class TestMultiGPULlama:
|
|||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -548,7 +554,10 @@ class TestMultiGPULlama:
|
|||||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"mean_resizing_embeddings": True,
|
"mean_resizing_embeddings": True,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
@@ -648,7 +657,10 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
@@ -722,7 +734,10 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
@@ -796,7 +811,10 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ class TestMultiGPUQwen2:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": base_model,
|
"base_model": base_model,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "chatml",
|
"chat_template": "chatml",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
|
|||||||
@@ -32,7 +32,10 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
|
|||||||
@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
},
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"gptq": True,
|
"gptq": True,
|
||||||
|
|||||||
@@ -33,7 +33,10 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
|
|||||||
@@ -34,7 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
|
|||||||
@@ -35,7 +35,10 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 4,
|
"lora_r": 4,
|
||||||
"lora_alpha": 8,
|
"lora_alpha": 8,
|
||||||
@@ -91,7 +94,10 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 4,
|
"lora_r": 4,
|
||||||
"lora_alpha": 8,
|
"lora_alpha": 8,
|
||||||
|
|||||||
141
tests/e2e/test_quantization.py
Normal file
141
tests/e2e/test_quantization.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for training with quantized model
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHQQ(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test cases for training of HQQ-quantized llama models"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_hqq_lora(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"use_hqq": True,
|
||||||
|
"hqq_config": [
|
||||||
|
{
|
||||||
|
"nbits": 8,
|
||||||
|
"group_size": 64,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_hqq_qlora(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"use_hqq": True,
|
||||||
|
"hqq_config": [
|
||||||
|
{
|
||||||
|
"nbits": 4,
|
||||||
|
"group_size": 64,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if is_torch_bf16_gpu_available():
|
||||||
|
cfg.bf16 = True
|
||||||
|
else:
|
||||||
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
|
# "load_in_4bit": True
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "",
|
"deepspeed": "",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": None,
|
"deepspeed": None,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"gptq": True,
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": False,
|
"quantization": {
|
||||||
|
"bits": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
with pytest.raises(ValueError, match=r".*bits <= 4*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"gptq": True,
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"bits": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
match=r"Quantization is not supported without setting an adapter.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
}
|
"bits": 4,
|
||||||
)
|
},
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
{
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,8 +21,10 @@ class TestModelsUtils:
|
|||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "JackFram/llama-68m",
|
||||||
"model_type": "LlamaForCausalLM",
|
"model_type": "LlamaForCausalLM",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
"load_in_4bit": False,
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user