Compare commits

..

3 Commits

Author SHA1 Message Date
NanoCode012
970b2a6f2f feat: test for config validation and BC for new peft weight dtype 2026-02-16 21:26:04 +07:00
NanoCode012
1f7f5e7c26 feat: handle lora kernels compat with torchao 2026-02-16 21:25:50 +07:00
NanoCode012
60c0a828cc feat: add torchao's int4, nf4, int8 2026-02-16 21:25:24 +07:00
24 changed files with 507 additions and 78 deletions

View File

@@ -2,21 +2,21 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
liger-kernel==0.6.4
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers @ git+https://github.com/winglian/transformers.git@refactor-inner-training-loop-reorder-only
transformers==5.0.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
trl==0.27.1
hf_xet==1.2.0
kernels==0.11.5
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.16.0
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1

View File

@@ -246,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -52,8 +53,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
@@ -134,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length")
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length")
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -154,8 +157,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self, model=None):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer(model=model)
return super().create_optimizer()
opt_model = self.model if model is None else model
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer

View File

@@ -15,7 +15,7 @@ from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .quantize import dequantize_weight
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
@@ -46,6 +46,12 @@ def get_lora_parameters(
W = base_layer.weight
b = base_layer.bias
# Unwrap DTensor if FSDP2 left the weight wrapped -- DTensor does not proxy
# attribute access to the underlying tensor subclass, so torchao methods like
# .dequantize() or .get_original_weight() would not be visible.
if isinstance(W, DTensor):
W = W.full_tensor()
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None
@@ -86,6 +92,7 @@ def matmul_lora(
B: torch.Tensor | None,
s: float | None,
out: torch.Tensor | None = None,
transpose: bool = True,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
@@ -98,12 +105,15 @@ def matmul_lora(
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
transpose: If True (default), transpose W before matmul (forward path).
Set to False for backward paths where W is already in the correct layout.
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
is_quantized = W_quant is not None or type(W) is not torch.Tensor
W = dequantize_weight(W, W_quant, transpose=transpose)
reshape = False
if X.dim() == 3:
@@ -112,7 +122,7 @@ def matmul_lora(
reshape = True
out = torch.matmul(X, W, out=out)
if W_quant is not None:
if is_quantized:
del W
if A is not None:
@@ -292,15 +302,16 @@ class LoRA_MLP(torch.autograd.Function):
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
# Down projection (backward: no transpose needed, W is already [out, in])
grad_down = matmul_lora(
grad_output,
down_weight.t(),
down_weight,
None,
down_quant,
down_B,
down_A,
down_scale,
transpose=False,
)
# Activation backward
@@ -332,7 +343,7 @@ class LoRA_MLP(torch.autograd.Function):
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
up_weight = dequantize_weight(up_weight, up_quant, transpose=True)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
@@ -344,7 +355,7 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight, gate_quant)
gate_weight = dequantize_weight(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
del gate_weight
@@ -631,7 +642,7 @@ class LoRA_QKV(torch.autograd.Function):
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
q_weight_t = dequantize_weight(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
@@ -639,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
k_weight_t = dequantize_weight(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
@@ -647,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
v_weight_t = dequantize_weight(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
@@ -810,7 +821,7 @@ class LoRA_O(torch.autograd.Function):
d_B = s * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
W = dequantize_weight(W, W_quant, transpose=True)
dX = dY @ W.t()
del W

View File

@@ -146,3 +146,43 @@ def dequantize(
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out
def dequantize_weight(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
transpose: bool = False,
) -> torch.Tensor:
"""Unified dequantization for both torchao and bnb quantized weights.
For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes
using the appropriate instance method. For bnb Params4bit, delegates to the
optimized CUDA kernel in ``dequantize``.
Args:
W: Quantized weight tensor ``[out_features, in_features]``.
quant_state: bnb ``QuantState`` (None for torchao / unquantized).
transpose: If True, return ``[in_features, out_features]``.
Returns:
Dequantized float tensor, optionally transposed.
"""
# torchao path: tensor subclass with embedded quantization state
if quant_state is None and type(W) is not torch.Tensor:
result = None
# NF4Tensor (check first — NF4Tensor.dequantize is a static method)
if hasattr(W, "get_original_weight"):
result = W.get_original_weight()
else:
# AffineQuantizedTensor (INT4, etc.)
try:
result = W.dequantize()
except (TypeError, RuntimeError):
pass
if result is not None:
return result.t() if transpose else result
# bnb path: transpose input before the CUDA kernel (existing convention)
if transpose:
return dequantize(W.t(), quant_state)
return dequantize(W, quant_state)

View File

@@ -23,6 +23,7 @@ from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import TorchAOQuantDType
LOG = get_logger(__name__)
@@ -134,11 +135,13 @@ def load_lora(
rank = int(os.environ.get("LOCAL_RANK", 0))
is_torchao = cfg.peft and cfg.peft.backend == "torchao"
if (
cfg.fsdp_config
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_meta_for_peft(model)
@@ -146,6 +149,15 @@ def load_lora(
if cfg.peft_autocast_adapter_dtype is not None:
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
# Patch PEFT's torchao dispatch before any model creation/loading.
# Must happen before both get_peft_model and PeftModel.from_pretrained,
# as both trigger LoRA layer dispatch that would fail for INT4/NF4 weights.
# INT8 is natively supported by PEFT's TorchaoLoraLinear, so skip the patch.
if is_torchao and cfg.peft.weight_dtype != TorchAOQuantDType.int8:
from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch
patch_peft_torchao_dispatch()
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
if cfg.lora_on_cpu:
@@ -172,6 +184,7 @@ def load_lora(
and cfg.adapter
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
and not is_torchao
):
setup_quantized_peft_meta_for_training(model)

View File

@@ -158,6 +158,15 @@ class ModelLoader:
"""Property that determines if FSDP with QLoRA is enabled."""
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
@property
def is_torchao_qlora(self):
"""Property that determines if torchao backend is used for QLoRA."""
return (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
)
@send_errors
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -491,8 +500,9 @@ class ModelLoader:
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
if self.is_fsdp_enabled:
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
if self.is_qlora_and_fsdp_enabled:
# For QLoRA + FSDP with bnb, we still need to set device_map for proper initialization
# torchao tensors work natively with FSDP2, no device_map override needed
if self.is_qlora_and_fsdp_enabled and not self.is_torchao_qlora:
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
@@ -561,6 +571,44 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif (
self.cfg.adapter == "qlora"
and self.cfg.peft
and self.cfg.peft.backend == "torchao"
and not self.cfg.merge_lora
):
from transformers import TorchAoConfig
from axolotl.utils.schemas.enums import TorchAOQuantDType
weight_dtype = self.cfg.peft.weight_dtype
if weight_dtype == TorchAOQuantDType.int4:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int4_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.int8:
group_size = self.cfg.peft.group_size or 128
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type="int8_weight_only",
group_size=group_size,
)
elif weight_dtype == TorchAOQuantDType.nf4:
from torchao.dtypes._nf4tensor_api import NF4WeightOnlyConfig
block_size = self.cfg.peft.group_size or 64
self.model_kwargs["quantization_config"] = TorchAoConfig(
quant_type=NF4WeightOnlyConfig(
block_size=block_size,
scaler_block_size=256,
),
)
else:
raise ValueError(
f"Unsupported torchao weight_dtype for QLoRA: {weight_dtype}. "
"Supported: int4, int8, nf4"
)
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
@@ -860,6 +908,10 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
# torchao quantized models don't use Params4bit and don't need kbit preparation
if self.is_torchao_qlora:
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -10,7 +10,6 @@ from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
@@ -349,10 +348,12 @@ class PatchManager:
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
is_torchao = self.cfg.peft and self.cfg.peft.backend == "torchao"
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and not is_torchao
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
@@ -501,7 +502,6 @@ class PatchManager:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately

View File

@@ -59,12 +59,7 @@ class CPU_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad = True
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Newer HF models (e.g. Qwen3MoE) using GradientCheckpointingLayer
# return a plain tensor, not a tuple. Older models return tuples
# like (hidden_states, present_kv, ...). Unwrap if needed.
if isinstance(output, (tuple, list)):
(output,) = output
(output,) = ctx.forward_function(hidden_states, *ctx.args)
torch.autograd.backward(output, dY)
return (
None,

View File

@@ -78,3 +78,30 @@ def patch_peft_prep_code():
axolotl.loaders.model.prepare_model_for_kbit_training = (
fixed_prepare_model_for_kbit_training
)
def patch_peft_torchao_dispatch():
"""Skip PEFT's TorchaoLoraLinear for non-INT8 torchao weights.
PEFT's dispatch_torchao() matches AffineQuantizedTensor but then errors in
_check_dtype_supported() because it only allows INT8. Our LoRA kernels handle
dequantization explicitly, so we bypass PEFT's torchao dispatch entirely and
let it fall back to standard Linear LoRA layers.
"""
try:
from peft.tuners.lora import torchao as peft_torchao
except ImportError:
LOG.warning("Could not import peft.tuners.lora.torchao for patching")
return
if getattr(peft_torchao, "_axolotl_patched", False):
return
def patched_dispatch(target, adapter_name, lora_config, **kwargs):
# Return None so PEFT falls back to standard Linear LoRA layers.
# Our LoRA kernels handle torchao dequantization explicitly.
return None
peft_torchao.dispatch_torchao = patched_dispatch
peft_torchao._axolotl_patched = True
LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear")

View File

@@ -28,12 +28,8 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item()"
)
PATCHED_MAYBE_CODE = (
"tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).nanmean().item()"
)
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
def check_evaluation_loop_is_patchable() -> bool:

View File

@@ -446,16 +446,7 @@ class AxolotlInputConfig(
},
)
unfrozen_parameters: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "List of regex patterns for parameter names to keep unfrozen. "
"All other parameters will be frozen via requires_grad=False. "
"Note: range-based patterns (e.g. embed_tokens.weight$[:32000]) use gradient "
"zeroing rather than a true freeze, so weight decay will still apply to the "
"frozen portion and optimizer states are allocated for the full parameter."
},
)
unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(
default=512,

View File

@@ -8,6 +8,7 @@ import torch
class TorchAOQuantDType(Enum):
int4 = torch.int4
int8 = torch.int8
nf4 = "nf4"
float8_e4m3fn = torch.float8_e4m3fn
nvfp4 = "nvfp4"
@@ -16,6 +17,8 @@ class TorchAOQuantDType(Enum):
return TorchAOQuantDType.int4
if str == "int8":
return TorchAOQuantDType.int8
if str == "nf4":
return TorchAOQuantDType.nf4
if str in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if str == "nvfp4":

View File

@@ -1,9 +1,12 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator, model_validator
from axolotl.utils.schemas.enums import TorchAOQuantDType
from axolotl.utils.schemas.quantization import validate_ao_dtype
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -15,7 +18,7 @@ class LoftQConfig(BaseModel):
class PeftConfig(BaseModel):
"""peftq configuration subset"""
"""PEFT configuration subset"""
loftq_config: LoftQConfig | None = Field(
default=None,
@@ -23,6 +26,29 @@ class PeftConfig(BaseModel):
"description": "Configuration options for loftq initialization for LoRA"
},
)
backend: Literal["bnb", "torchao"] | None = Field(
default=None,
json_schema_extra={
"description": "Quantization backend for QLoRA. 'bnb' for bitsandbytes (default), 'torchao' for torchao."
},
)
weight_dtype: TorchAOQuantDType | None = Field(
default=None,
json_schema_extra={
"description": "Weight quantization dtype (int4, int8, or nf4). Also used with bnb backend to auto-configure quantization."
},
)
group_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Group size for quantization. Defaults to 128 for int4, 64 for nf4."
},
)
@field_validator("weight_dtype", mode="before")
@classmethod
def validate_weight_dtype(cls, v):
return validate_ao_dtype(v)
class LoraConfig(BaseModel):
@@ -156,6 +182,56 @@ class LoraConfig(BaseModel):
merge_lora: bool | None = None
@model_validator(mode="before")
@classmethod
def auto_detect_qlora(cls, data):
"""Auto-set adapter type and quantization flags from peft config.
When peft.backend and peft.weight_dtype are set, this infers the correct
adapter type and internal flags (load_in_4bit, load_in_8bit) so users
don't need to set them manually.
"""
peft = data.get("peft")
if not isinstance(peft, dict):
return data
backend = peft.get("backend")
weight_dtype = peft.get("weight_dtype")
# Validate: weight_dtype requires backend
if weight_dtype and not backend:
raise ValueError(
"peft.backend is required when peft.weight_dtype is set. "
"Use 'torchao' or 'bnb'."
)
if not weight_dtype:
return data
adapter = data.get("adapter")
if backend == "torchao":
# torchao: any quantized weight_dtype means qlora
if adapter == "lora":
data["adapter"] = "qlora"
elif backend == "bnb":
if weight_dtype == "nf4":
# bnb nf4 = qlora with load_in_4bit
if adapter == "lora":
data["adapter"] = "qlora"
data.setdefault("load_in_4bit", True)
elif weight_dtype == "int8":
# bnb int8 = lora with load_in_8bit
data.setdefault("load_in_8bit", True)
else:
raise ValueError(
f"peft.weight_dtype '{weight_dtype}' is not supported with bnb backend. "
"Supported: nf4, int8."
)
return data
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
@@ -173,6 +249,8 @@ class LoraConfig(BaseModel):
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
is_torchao = self.peft and self.peft.backend == "torchao"
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
@@ -184,7 +262,20 @@ class LoraConfig(BaseModel):
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
elif is_torchao:
# torchao backend: validate torchao-specific requirements
if self.load_in_4bit or self.load_in_8bit:
raise ValueError(
"load_in_4bit/load_in_8bit are for bitsandbytes. "
"With peft.backend: torchao, quantization is handled by torchao."
)
if not self.peft.weight_dtype:
raise ValueError(
"peft.weight_dtype is required when peft.backend is 'torchao'"
)
else:
# Default bnb path
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")

View File

@@ -16,6 +16,8 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
return TorchAOQuantDType.int4
if v == "int8":
return TorchAOQuantDType.int8
if v == "nf4":
return TorchAOQuantDType.nf4
if v in ["float8_e4m3fn", "fp8", "float8"]:
return TorchAOQuantDType.float8_e4m3fn
if v == "nvfp4":

View File

@@ -300,6 +300,7 @@ class TestHFRLTrainerBuilder:
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)

View File

@@ -3,7 +3,7 @@
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
from axolotl.kernels.quantize import dequantize, dequantize_weight
def test_dequantize_null_state():
@@ -100,3 +100,18 @@ def test_dequantize_output_tensor():
result = dequantize(W, quant_state, out=out)
assert result is out
def test_dequantize_weight_plain_tensor():
"""Test that dequantize_weight passes through unquantized tensors unchanged"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=False)
assert torch.equal(result, W)
def test_dequantize_weight_plain_tensor_transpose():
"""Test that dequantize_weight transposes unquantized tensors"""
W = torch.randn(32, 64)
result = dequantize_weight(W, quant_state=None, transpose=True)
assert result.shape == (64, 32)
assert torch.equal(result, W.t())

View File

@@ -186,7 +186,6 @@ class TestFSDP1:
verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test, deprecate fsdp1 asap")
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{

View File

@@ -365,7 +365,6 @@ class TestFSDP2:
verify_training_success(temp_dir)
@pytest.mark.skip(reason="slow test w cu129 + torch 2.9.1 + py3.12")
@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(

View File

@@ -115,9 +115,6 @@ class TestAssistantChatTemplateLlama3:
def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset")
assert "LlamaTokenizer" in phi35_tokenizer.__class__.__name__, (
"phi35 tokenizer should be a LlamaTokenizer"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
@@ -143,13 +140,13 @@ class TestAssistantChatTemplateLlama3:
# fmt: off
expected_input_ids = [
32010, # user
12199, 32007, # user eot
22172, 32007, # user eot
32001, # assistant
12199, 32007, # assistant eot
22172, 32007, # assistant eot
32010, # user
16773, 26966, 32007, # user eot
1781, 26966, 32007, # user eot
32001, # assistant
16773, 26966, 32007, # assistant eot
1781, 26966, 32007, # assistant eot
]
expected_labels = [
-100, # user
@@ -159,7 +156,7 @@ class TestAssistantChatTemplateLlama3:
-100, # user
-100, -100, -100, # user eot
-100, # assistant
16773, 26966, 32007, # assistant eot
1781, 26966, 32007, # assistant eot
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")

View File

@@ -84,8 +84,7 @@ class TestTokenizers:
}
)
tokenizer = load_tokenizer(cfg)
assert "LlamaTokenizer" in tokenizer.__class__.__name__
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1792]
assert tokenizer("<|im_start|>user")["input_ids"] == [1, 32000, 1404]
assert len(tokenizer) == 32001
# ensure reloading the tokenizer again from cfg results in same vocab length

View File

@@ -3,6 +3,14 @@ import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
BASE_CFG = {
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
class TestLoRAConfigValidation:
"""Test suite for LoRA/QLoRA configuration validation"""
@@ -149,3 +157,195 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["lora_qkv_kernel"] is True
assert result["trust_remote_code"] is None
class TestTorchaoQLoRAConfigValidation:
"""Test suite for torchao QLoRA auto-detection and validation"""
# --- Auto-detection: torchao ---
@pytest.mark.parametrize("weight_dtype", ["int4", "int8", "nf4"])
def test_torchao_auto_detect_from_lora(self, weight_dtype):
"""adapter: lora + peft.backend: torchao auto-upgrades to qlora"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "torchao", "weight_dtype": weight_dtype},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["peft"]["backend"] == "torchao"
def test_torchao_explicit_qlora(self):
"""adapter: qlora + peft.backend: torchao works directly"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
# --- Auto-detection: bnb ---
def test_bnb_nf4_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: nf4 → qlora + load_in_4bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_bnb_int8_auto_detect_from_lora(self):
"""adapter: lora + peft.backend: bnb + weight_dtype: int8 → lora + load_in_8bit"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int8"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_bnb_nf4_explicit_qlora_auto_sets_load_in_4bit(self):
"""adapter: qlora + peft.backend: bnb + weight_dtype: nf4 auto-sets load_in_4bit"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
# --- Backward compat ---
def test_old_style_qlora_unchanged(self):
"""Old-style adapter: qlora + load_in_4bit: true still works"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True
def test_old_style_lora_8bit_unchanged(self):
"""Old-style adapter: lora + load_in_8bit: true still works"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_8bit": True,
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
assert result["load_in_8bit"] is True
def test_plain_lora_unchanged(self):
"""adapter: lora without peft block stays as lora"""
cfg = DictDefault(
{
"adapter": "lora",
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "lora"
# --- Validation errors ---
def test_torchao_with_load_in_4bit_errors(self):
"""peft.backend: torchao + load_in_4bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_4bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_with_load_in_8bit_errors(self):
"""peft.backend: torchao + load_in_8bit is a conflict"""
cfg = DictDefault(
{
"adapter": "qlora",
"load_in_8bit": True,
"peft": {"backend": "torchao", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"):
validate_config(cfg)
def test_torchao_without_weight_dtype_errors(self):
"""peft.backend: torchao without weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "qlora",
"peft": {"backend": "torchao"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.weight_dtype is required"):
validate_config(cfg)
def test_weight_dtype_without_backend_errors(self):
"""peft.weight_dtype without peft.backend errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="peft.backend is required"):
validate_config(cfg)
def test_bnb_unsupported_weight_dtype_errors(self):
"""peft.backend: bnb + unsupported weight_dtype errors"""
cfg = DictDefault(
{
"adapter": "lora",
"peft": {"backend": "bnb", "weight_dtype": "int4"},
**BASE_CFG,
}
)
with pytest.raises(ValueError, match="not supported with bnb"):
validate_config(cfg)
# --- Redundant flags don't conflict ---
def test_bnb_nf4_with_explicit_load_in_4bit(self):
"""peft.backend: bnb + weight_dtype: nf4 + load_in_4bit: true is fine (redundant)"""
cfg = DictDefault(
{
"adapter": "lora",
"load_in_4bit": True,
"peft": {"backend": "bnb", "weight_dtype": "nf4"},
**BASE_CFG,
}
)
result = validate_config(cfg)
assert result["adapter"] == "qlora"
assert result["load_in_4bit"] is True