Compare commits
4 Commits
feat_hqq
...
smaller-ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0670abc94 | ||
|
|
08f287b57f | ||
|
|
b4c7d9c29d | ||
|
|
d2637fb01d |
@@ -1,10 +1,13 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
# only run one test at a time so as not to OOM the GPU
|
||||||
|
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
|
||||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||||
pytest -v -n2 \
|
pytest -v -n2 \
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
|
||||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||||
--cov=axolotl \
|
--cov=axolotl \
|
||||||
--cov-report=xml:multigpu-coverage.xml
|
--cov-report=xml:multigpu-coverage.xml
|
||||||
@@ -14,11 +17,6 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
|
|||||||
--cov-append \
|
--cov-append \
|
||||||
--cov-report=xml:multigpu-coverage.xml
|
--cov-report=xml:multigpu-coverage.xml
|
||||||
|
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
|
|
||||||
--cov=axolotl \
|
|
||||||
--cov-append \
|
|
||||||
--cov-report=xml:multigpu-coverage.xml
|
|
||||||
|
|
||||||
# Upload coverage to Codecov
|
# Upload coverage to Codecov
|
||||||
if [ -f multigpu-coverage.xml ]; then
|
if [ -f multigpu-coverage.xml ]; then
|
||||||
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
|
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
|
||||||
|
|||||||
@@ -55,46 +55,20 @@ overrides_of_model_config:
|
|||||||
overrides_of_model_kwargs:
|
overrides_of_model_kwargs:
|
||||||
# use_cache: False
|
# use_cache: False
|
||||||
|
|
||||||
|
# optional overrides to the bnb 4bit quantization configuration
|
||||||
|
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||||
|
bnb_config_kwargs:
|
||||||
|
# These are default values
|
||||||
|
llm_int8_has_fp16_weight: false
|
||||||
|
bnb_4bit_quant_type: nf4
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
|
||||||
# Quantization configuration.
|
|
||||||
quantization:
|
|
||||||
backend: bnb | hqq | gptq
|
|
||||||
bits: 8
|
|
||||||
# optional overrides to the bnb 4bit quantization configuration
|
|
||||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
|
||||||
bnb_config_kwargs:
|
|
||||||
# These are default values
|
|
||||||
llm_int8_has_fp16_weight: false
|
|
||||||
bnb_4bit_quant_type: nf4
|
|
||||||
bnb_4bit_use_double_quant: true
|
|
||||||
|
|
||||||
# If using hqq config, additional config parameters are needed. See: https://huggingface.co/docs/transformers/main/en/quantization/hqq
|
|
||||||
hqq_config:
|
|
||||||
# pick one of the following, depending on if you want to uniformly quantize the whole model or
|
|
||||||
# apply different quantization settings to specific layers in the model:
|
|
||||||
|
|
||||||
# to uniformly quantize the whole model:
|
|
||||||
group_size: 64
|
|
||||||
# to use dynamic_config to apply different quantization settings to specific layers:
|
|
||||||
- nbits: 4
|
|
||||||
group_size: 64
|
|
||||||
target_modules:
|
|
||||||
- self_attn.k_proj
|
|
||||||
- self_attn.v_proj
|
|
||||||
- self_attn.o_proj
|
|
||||||
- nbits: 3
|
|
||||||
group_size: 32
|
|
||||||
target_modules:
|
|
||||||
- mlp.gate_proj
|
|
||||||
- mlp.up_proj
|
|
||||||
- mlp.down_proj
|
|
||||||
|
|
||||||
# (Internal Use Only)
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq:
|
gptq: true
|
||||||
|
|
||||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||||
load_in_8bit:
|
load_in_8bit: true
|
||||||
# Use bitsandbytes 4 bit
|
# Use bitsandbytes 4 bit
|
||||||
load_in_4bit:
|
load_in_4bit:
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ triton>=3.0.0
|
|||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.8
|
liger-kernel==0.5.6
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
@@ -22,7 +22,6 @@ hf_xet==1.0.0
|
|||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
hqq==0.2.5
|
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.23.3
|
gradio==5.23.3
|
||||||
|
|
||||||
|
|||||||
@@ -1040,11 +1040,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dataset_processes:
|
if self.cfg.dataset_processes:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||||
elif self.cfg.rl_beta is not None:
|
if self.cfg.orpo_alpha:
|
||||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
|
||||||
elif self.cfg.orpo_alpha is not None:
|
|
||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
|
|||||||
- deepseek_v2
|
- deepseek_v2
|
||||||
- gemma
|
- gemma
|
||||||
- gemma2
|
- gemma2
|
||||||
- gemma3
|
- gemma3 (partial support, no support for FLCE yet)
|
||||||
- granite
|
- granite
|
||||||
- jamba
|
- jamba
|
||||||
- llama
|
- llama
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ class LigerPlugin(BasePlugin):
|
|||||||
)
|
)
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
@@ -139,6 +141,38 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
|
||||||
|
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||||
|
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||||
|
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||||
|
_liger_rms_norm_wrapper,
|
||||||
|
offset=1.0,
|
||||||
|
casting_mode="gemma",
|
||||||
|
init_fn="zeros",
|
||||||
|
in_place=False,
|
||||||
|
)
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||||
|
)
|
||||||
elif cfg.model_config_type == "llama4":
|
elif cfg.model_config_type == "llama4":
|
||||||
from axolotl.integrations.liger.models.llama4 import (
|
from axolotl.integrations.liger.models.llama4 import (
|
||||||
apply_liger_kernel_to_llama4,
|
apply_liger_kernel_to_llama4,
|
||||||
|
|||||||
@@ -236,18 +236,6 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
if cfg.quantization:
|
|
||||||
if cfg.quantization.backend in ["bnb"]:
|
|
||||||
if cfg.quantization.bits == 8:
|
|
||||||
cfg.load_in_8bit = True
|
|
||||||
elif cfg.quantization.bits == 4:
|
|
||||||
cfg.load_in_4bit = True
|
|
||||||
|
|
||||||
if cfg.quantization.backend == "gptq":
|
|
||||||
cfg.gptq = True
|
|
||||||
elif cfg.quantization.backend == "hqq":
|
|
||||||
cfg.hqq = True
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_cfg_datasets(cfg):
|
def normalize_cfg_datasets(cfg):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -118,27 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
|||||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||||
)
|
)
|
||||||
|
|
||||||
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
|
iter_ds = load_dataset(
|
||||||
# other ranks, we just need to present a fake dataset
|
path, streaming=True, split=split, name=name, data_files=data_files
|
||||||
if (
|
)
|
||||||
cfg.accelerator_config
|
|
||||||
and cfg.accelerator_config.dispatch_batches
|
|
||||||
and not is_local_main_process()
|
|
||||||
):
|
|
||||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
|
|
||||||
f.write("text\n")
|
|
||||||
f.write("lorem ipsum dolor sit amet\n")
|
|
||||||
# rewind the file pointer to the beginning so we can read it again
|
|
||||||
f.seek(0)
|
|
||||||
iter_ds = load_dataset(
|
|
||||||
"csv", data_files=f.name, split="train", streaming=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if is_local_main_process():
|
|
||||||
iter_ds = load_dataset(
|
|
||||||
path, streaming=True, split=split, name=name, data_files=data_files
|
|
||||||
)
|
|
||||||
|
|
||||||
if skip:
|
if skip:
|
||||||
LOG.info(f"Skipping {skip} samples from the dataset")
|
LOG.info(f"Skipping {skip} samples from the dataset")
|
||||||
iter_ds = iter_ds.skip(skip)
|
iter_ds = iter_ds.skip(skip)
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from transformers import (
|
|||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
HqqConfig,
|
|
||||||
Llama4ForConditionalGeneration,
|
Llama4ForConditionalGeneration,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
Mistral3ForConditionalGeneration,
|
||||||
@@ -834,13 +833,6 @@ class ModelLoader:
|
|||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
def set_quantization_config(self) -> None:
|
def set_quantization_config(self) -> None:
|
||||||
if (
|
|
||||||
(not self.cfg.quantization)
|
|
||||||
and (not self.cfg.load_in_8bit)
|
|
||||||
and (not self.cfg.load_in_4bit)
|
|
||||||
and not self.cfg.gptq
|
|
||||||
):
|
|
||||||
return
|
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||||
|
|
||||||
@@ -862,21 +854,21 @@ class ModelLoader:
|
|||||||
and hasattr(self.model_config, "quantization_config")
|
and hasattr(self.model_config, "quantization_config")
|
||||||
and self.model_config.quantization_config["quant_method"]
|
and self.model_config.quantization_config["quant_method"]
|
||||||
in ["gptq", "awq", "bitsandbytes"]
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
and not self.cfg.hqq
|
|
||||||
):
|
):
|
||||||
quant_config_class_dict = {
|
if self.model_config.quantization_config["quant_method"] == "gptq":
|
||||||
"gptq": GPTQConfig,
|
self.model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
"awq": AwqConfig,
|
**self.model_config.quantization_config
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
)
|
||||||
}
|
elif self.model_config.quantization_config["quant_method"] == "awq":
|
||||||
|
self.model_kwargs["quantization_config"] = AwqConfig(
|
||||||
quant_config_class = quant_config_class_dict[
|
**self.model_config.quantization_config
|
||||||
self.model_config.quantization_config["quant_method"]
|
)
|
||||||
]
|
elif (
|
||||||
self.model_kwargs["quantization_config"] = quant_config_class(
|
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
|
||||||
**self.model_config.quantization_config
|
):
|
||||||
)
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
**self.model_config.quantization_config
|
||||||
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
@@ -894,8 +886,8 @@ class ModelLoader:
|
|||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
|
||||||
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
|
if self.cfg.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
@@ -911,13 +903,6 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.hqq:
|
|
||||||
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
|
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = HqqConfig(
|
|
||||||
**get_hqq_quant_config_kwargs(self.cfg)
|
|
||||||
)
|
|
||||||
|
|
||||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||||
self.model_kwargs.pop("load_in_8bit", None)
|
self.model_kwargs.pop("load_in_8bit", None)
|
||||||
@@ -1051,12 +1036,6 @@ class ModelLoader:
|
|||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.cfg.hqq and torch.cuda.device_count() < 2:
|
|
||||||
# for some reason on single gpu, we need to set device_map to auto/cuda
|
|
||||||
# otherwise you run into tensors on two devices error during training
|
|
||||||
# This does not affect multi-GPU runs.
|
|
||||||
|
|
||||||
self.model_kwargs["device_map"] = "auto"
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -1211,7 +1190,7 @@ class ModelLoader:
|
|||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
|
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
self.model = prepare_model_for_kbit_training(
|
self.model = prepare_model_for_kbit_training(
|
||||||
@@ -1481,16 +1460,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
from hqq.core.peft import HQQLinearLoRA
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
from hqq.core.quantize import HQQLinear
|
|
||||||
|
|
||||||
cls = (
|
|
||||||
bnb.nn.Linear4bit,
|
|
||||||
bnb.nn.Linear8bitLt,
|
|
||||||
torch.nn.Linear,
|
|
||||||
HQQLinear,
|
|
||||||
HQQLinearLoRA,
|
|
||||||
)
|
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
|
|||||||
self.max_lr = max_lr
|
self.max_lr = max_lr
|
||||||
self.total_steps = total_steps
|
self.total_steps = total_steps
|
||||||
self.num_warmup_steps = num_warmup_steps
|
self.num_warmup_steps = num_warmup_steps
|
||||||
self.last_step = max(last_step - 1, 0)
|
self.last_step = last_step - 1
|
||||||
|
|
||||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
|
|||||||
@@ -660,7 +660,6 @@ class AxolotlInputConfig(
|
|||||||
data.get("val_set_size") == 0
|
data.get("val_set_size") == 0
|
||||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||||
and not data.get("test_datasets")
|
and not data.get("test_datasets")
|
||||||
and data.get("eval_strategy") != "no"
|
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from typing import Any
|
||||||
|
|
||||||
from axolotl.utils.schemas.quant import QuantizationConfig
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
@@ -23,11 +23,8 @@ class PeftConfig(BaseModel):
|
|||||||
class LoraConfig(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
"""Peft / LoRA configuration subset"""
|
"""Peft / LoRA configuration subset"""
|
||||||
|
|
||||||
quantization: QuantizationConfig | None = None
|
load_in_8bit: bool | None = Field(default=False)
|
||||||
load_in_4bit: bool | None = None # for internal use
|
load_in_4bit: bool | None = Field(default=False)
|
||||||
load_in_8bit: bool | None = None # for internal use
|
|
||||||
hqq: bool | None = None # for internal use
|
|
||||||
gptq: bool | None = None # for internal use
|
|
||||||
|
|
||||||
adapter: str | None = None
|
adapter: str | None = None
|
||||||
lora_model_dir: str | None = None
|
lora_model_dir: str | None = None
|
||||||
@@ -53,6 +50,8 @@ class LoraConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
lora_on_cpu: bool | None = None
|
lora_on_cpu: bool | None = None
|
||||||
|
gptq: bool | None = None
|
||||||
|
bnb_config_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
loraplus_lr_ratio: float | None = Field(
|
loraplus_lr_ratio: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -75,11 +74,11 @@ class LoraConfig(BaseModel):
|
|||||||
if (
|
if (
|
||||||
not data.get("adapter")
|
not data.get("adapter")
|
||||||
and not data.get("inference")
|
and not data.get("inference")
|
||||||
and (data.get("quantization"))
|
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Quantization is not supported without setting an adapter."
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||||
"If you want to full finetune, please turn off Quantization."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -87,26 +86,25 @@ class LoraConfig(BaseModel):
|
|||||||
def validate_qlora(self):
|
def validate_qlora(self):
|
||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
if self.quantization.bits == 8 or self.load_in_8bit:
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
|
if self.load_in_8bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.gptq:
|
||||||
raise ValueError("Can't merge qlora if using gptq")
|
raise ValueError("Can't merge qlora if gptq")
|
||||||
|
|
||||||
if self.quantization.bits == 4 or self.load_in_4bit:
|
if self.load_in_4bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.quantization:
|
if self.load_in_8bit:
|
||||||
if self.quantization.bits == 8 or self.load_in_8bit:
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.gptq:
|
||||||
raise ValueError("Can't load qlora if using gptq")
|
raise ValueError("Can't load qlora if gptq")
|
||||||
|
|
||||||
if not self.quantization.bits == 4 or self.load_in_4bit:
|
|
||||||
raise ValueError("Require quantization.bits <= 4 for qlora")
|
|
||||||
|
|
||||||
|
if not self.load_in_4bit:
|
||||||
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
@field_validator("loraplus_lr_embedding")
|
||||||
@@ -123,24 +121,6 @@ class LoraConfig(BaseModel):
|
|||||||
data["lora_dropout"] = 0.0
|
data["lora_dropout"] = 0.0
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_hqq(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("quantization")
|
|
||||||
and data.get("quantization").get("backend") == "hqq"
|
|
||||||
):
|
|
||||||
if not data.get("quantization").get("hqq_config"):
|
|
||||||
raise ValueError(
|
|
||||||
"If using HQQ, must set `hqq_config` under `quantization`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.get("load_in_4bit") or data.get("load_in_8bit"):
|
|
||||||
raise ValueError(
|
|
||||||
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
""" "
|
|
||||||
Takes care of quantization configuration
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Annotated, Any, Literal
|
|
||||||
|
|
||||||
from annotated_types import MinLen
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
|
|
||||||
|
|
||||||
class HQQConfig(BaseModel):
|
|
||||||
"""HQQ configuration subset"""
|
|
||||||
|
|
||||||
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
group_size: int = Field(default=64)
|
|
||||||
target_modules: list[str] | str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QuantizationConfig(BaseModel):
|
|
||||||
"""Overall quantization configuration subset"""
|
|
||||||
|
|
||||||
# We will use this class as the base for a future refactoring of all quantization configs
|
|
||||||
backend: Literal["bnb", "hqq", "gptq"] | None = None
|
|
||||||
bits: Literal[8, 4, 3, 2, 1] | None = None
|
|
||||||
bnb_config_kwargs: dict[str, Any] | None = None
|
|
||||||
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_hqq_config(cls, data):
|
|
||||||
if data.get("backend") == "hqq" and not data.get("hqq_config"):
|
|
||||||
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
|
|
||||||
|
|
||||||
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
|
|
||||||
for hqq_config in data.get("hqq_config"):
|
|
||||||
if hqq_config.get("target_modules") is None:
|
|
||||||
raise ValueError(
|
|
||||||
"For list of hqq configs, `target_modules` must be specified for each"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def get_hqq_quant_config_kwargs(cfg):
|
|
||||||
|
|
||||||
# If no target module is specified, then target the whole model
|
|
||||||
if not isinstance(cfg.quantization.hqq_config, list):
|
|
||||||
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
|
|
||||||
|
|
||||||
if (
|
|
||||||
len(cfg.quantization.hqq_config) == 1
|
|
||||||
and cfg.quantization.hqq_config[0].target_modules is None
|
|
||||||
):
|
|
||||||
|
|
||||||
nbits = (
|
|
||||||
cfg.quantization.hqq_config[0].nbits
|
|
||||||
if cfg.quantization.hqq_config[0].nbits is not None
|
|
||||||
else cfg.quantization.bits
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"nbits": nbits,
|
|
||||||
"group_size": cfg.quantization.hqq_config[0].group_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
hqq_quant_config_kwargs = {"dynamic_config": {}}
|
|
||||||
for hqq_config in cfg.quantization.hqq_config:
|
|
||||||
nbits = (
|
|
||||||
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
|
|
||||||
)
|
|
||||||
|
|
||||||
target_modules = hqq_config.target_modules
|
|
||||||
if not isinstance(target_modules, list):
|
|
||||||
target_modules = [target_modules]
|
|
||||||
|
|
||||||
for module in target_modules:
|
|
||||||
hqq_quant_config_kwargs["dynamic_config"][module] = {
|
|
||||||
"nbits": nbits,
|
|
||||||
"group_size": hqq_config.group_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
return hqq_quant_config_kwargs
|
|
||||||
@@ -193,14 +193,6 @@ def download_tiny_shakespeare_dataset():
|
|||||||
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_evolkit_kd_sample_dataset():
|
|
||||||
# download the dataset
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_deepseek_model_fixture():
|
def download_deepseek_model_fixture():
|
||||||
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
||||||
@@ -216,16 +208,6 @@ def download_huggyllama_model_fixture():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama33_70b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_llama_1b_model_fixture():
|
def download_llama_1b_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
@@ -333,14 +315,6 @@ def download_llama2_model_fixture():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama32_1b_model_fixture():
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"osllmai-community/Llama-3.2-1B",
|
|
||||||
repo_type="model",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def tokenizer_huggyllama(
|
def tokenizer_huggyllama(
|
||||||
@@ -522,6 +496,12 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
|||||||
return datasets.load_from_disk(ds_path)["train"]
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_tiny_llama_7m_model():
|
||||||
|
# download the model
|
||||||
|
return snapshot_download_w_retry("axolotl-ai-internal/llama-7m", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
# # pylint: disable=redefined-outer-name,unused-argument
|
# # pylint: disable=redefined-outer-name,unused-argument
|
||||||
# def test_load_fixtures(
|
# def test_load_fixtures(
|
||||||
# download_smollm2_135m_model,
|
# download_smollm2_135m_model,
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,2 +0,0 @@
|
|||||||
# Tests under this directory should get run "solo" on their own as they
|
|
||||||
# seem to cause issues when run in the same batch as other tests.
|
|
||||||
|
|||||||
@@ -49,9 +49,8 @@ class TestPackedFlex:
|
|||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "tatsu-lab/alpaca",
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -90,5 +89,5 @@ class TestPackedFlex:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,10 +30,8 @@ class TestMultiGPUEval:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"quantization": {
|
"load_in_8bit": False,
|
||||||
"backend": "bnb",
|
"load_in_4bit": True,
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"strict": False,
|
"strict": False,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
@@ -101,10 +99,8 @@ class TestMultiGPUEval:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"quantization": {
|
"load_in_8bit": False,
|
||||||
"backend": "bnb",
|
"load_in_4bit": True,
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"strict": False,
|
"strict": False,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
|
|||||||
@@ -96,5 +96,5 @@ class TestMultiGPUGemma3:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 1.8, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -94,7 +94,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -105,7 +105,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
@@ -159,22 +159,19 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -247,15 +244,12 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
@@ -332,7 +326,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.01,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
@@ -391,7 +385,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -402,7 +396,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -463,7 +457,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_6_0
|
||||||
@@ -481,7 +475,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
@@ -544,7 +538,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.1, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
@@ -554,10 +548,7 @@ class TestMultiGPULlama:
|
|||||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"mean_resizing_embeddings": True,
|
"mean_resizing_embeddings": True,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
@@ -627,7 +618,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -657,16 +648,13 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -714,7 +702,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -734,16 +722,13 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -791,7 +776,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -811,16 +796,13 @@ class TestMultiGPULlama:
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
@@ -868,7 +850,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
@@ -878,7 +860,7 @@ class TestMultiGPULlama:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "axolotl-ai-internal/llama-7m",
|
||||||
"fix_untrained_tokens": True,
|
"fix_untrained_tokens": True,
|
||||||
"sequence_len": 512,
|
"sequence_len": 512,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
@@ -935,5 +917,5 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 4.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,10 +28,7 @@ class TestMultiGPUQwen2:
|
|||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": base_model,
|
"base_model": base_model,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "chatml",
|
"chat_template": "chatml",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class TestMultiGPURay:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_lt_2_6_0
|
@require_torch_lt_2_6_0
|
||||||
@@ -138,5 +138,5 @@ class TestMultiGPURay:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ...utils import check_tensorboard
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class TestSequenceParallelism:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.6, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -86,5 +86,5 @@ class TestFAXentropyLlama:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 1.5, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -32,10 +32,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
|
|||||||
@@ -89,9 +89,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"quantization": {
|
|
||||||
"backend": "gptq",
|
|
||||||
},
|
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"gptq": True,
|
"gptq": True,
|
||||||
|
|||||||
@@ -33,10 +33,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 16,
|
"lora_r": 16,
|
||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
|
|||||||
@@ -46,9 +46,8 @@ class TestResumeLlama:
|
|||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "tatsu-lab/alpaca",
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||||
@@ -130,7 +130,7 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -185,5 +185,5 @@ class TestUnslothQLoRA:
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,9 +41,8 @@ class TestPackedFlex(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "tatsu-lab/alpaca",
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -70,5 +69,5 @@ class TestPackedFlex(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,10 +34,7 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
|
|||||||
@@ -84,5 +84,5 @@ class TestPretrainLlama:
|
|||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/train_loss",
|
||||||
loss_threshold,
|
loss_threshold,
|
||||||
"Train Loss is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,10 +35,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 4,
|
"lora_r": 4,
|
||||||
"lora_alpha": 8,
|
"lora_alpha": 8,
|
||||||
@@ -94,10 +91,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 4,
|
"lora_r": 4,
|
||||||
"lora_alpha": 8,
|
"lora_alpha": 8,
|
||||||
|
|||||||
@@ -40,9 +40,8 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "tatsu-lab/alpaca",
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -69,5 +68,5 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,141 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for training with quantized model
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import check_tensorboard, with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestHQQ(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test cases for training of HQQ-quantized llama models"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_hqq_lora(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"use_hqq": True,
|
|
||||||
"hqq_config": [
|
|
||||||
{
|
|
||||||
"nbits": 8,
|
|
||||||
"group_size": 64,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 16,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "vicgalle/alpaca-gpt4",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
|
||||||
)
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_hqq_qlora(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"use_hqq": True,
|
|
||||||
"hqq_config": [
|
|
||||||
{
|
|
||||||
"nbits": 4,
|
|
||||||
"group_size": 64,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 16,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "vicgalle/alpaca-gpt4",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
|
||||||
)
|
|
||||||
@@ -73,6 +73,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
|
temp_dir + "/runs", "train/train_loss", 2.5, "Train loss (%s) is too high"
|
||||||
)
|
)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -74,11 +74,7 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
# "load_in_4bit": True
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -97,10 +93,7 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "",
|
"deepspeed": "",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -114,10 +107,7 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": None,
|
"deepspeed": None,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -316,10 +306,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -331,9 +318,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"gptq": True,
|
||||||
"backend": "gptq",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -345,24 +330,19 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_4bit": False,
|
||||||
"bits": None,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*bits <= 4*"):
|
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -384,10 +364,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"backend": "bnb",
|
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -399,10 +376,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"gptq": True,
|
||||||
"backend": "gptq",
|
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -414,9 +388,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"bits": 4,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -1004,9 +976,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_4bit": True,
|
||||||
"bits": None,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -1014,16 +984,29 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r"Quantization is not supported without setting an adapter.*",
|
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"bits": 4,
|
}
|
||||||
},
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1035,9 +1018,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|||||||
|
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
|
|
||||||
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"
|
MODEL_NAME = "axolotl-ai-internal/llama-7m"
|
||||||
|
|
||||||
|
|
||||||
@fixture()
|
@fixture()
|
||||||
@@ -36,7 +36,7 @@ One day, a little fish named Fin was swimming near the shore. He saw a big crab
|
|||||||
"""
|
"""
|
||||||
result = metric.compute(model, [sample_text])
|
result = metric.compute(model, [sample_text])
|
||||||
ppl = result["score"]
|
ppl = result["score"]
|
||||||
assert round(ppl, 2) == 7.41
|
assert round(ppl, 2) == 75.14
|
||||||
|
|
||||||
|
|
||||||
def test_perplexity_short(model, metric):
|
def test_perplexity_short(model, metric):
|
||||||
@@ -44,4 +44,4 @@ def test_perplexity_short(model, metric):
|
|||||||
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
|
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
|
||||||
result = metric.compute(model, [sample_text])
|
result = metric.compute(model, [sample_text])
|
||||||
ppl = result["score"]
|
ppl = result["score"]
|
||||||
assert round(ppl, 2) == 10.33
|
assert round(ppl, 2) == 70.54
|
||||||
|
|||||||
@@ -21,10 +21,8 @@ class TestModelsUtils:
|
|||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "JackFram/llama-68m",
|
||||||
"model_type": "LlamaForCausalLM",
|
"model_type": "LlamaForCausalLM",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
"quantization": {
|
"load_in_8bit": True,
|
||||||
"backend": "bnb",
|
"load_in_4bit": False,
|
||||||
"bits": 8,
|
|
||||||
},
|
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"flash_attention": False,
|
"flash_attention": False,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user