From 60c0a828cc6cefa74bd5fa7464b78a6bb80ea8c4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 16 Feb 2026 21:25:24 +0700 Subject: [PATCH] feat: add torchao's int4, nf4, int8 --- src/axolotl/loaders/adapter.py | 13 ++++ src/axolotl/loaders/model.py | 56 ++++++++++++- src/axolotl/loaders/patch_manager.py | 2 + src/axolotl/monkeypatch/peft/utils.py | 27 +++++++ src/axolotl/utils/schemas/enums.py | 3 + src/axolotl/utils/schemas/peft.py | 95 ++++++++++++++++++++++- src/axolotl/utils/schemas/quantization.py | 2 + 7 files changed, 194 insertions(+), 4 deletions(-) diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index 3b64b23db..d858c45f8 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -23,6 +23,7 @@ from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.telemetry.errors import send_errors from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger +from axolotl.utils.schemas.enums import TorchAOQuantDType LOG = get_logger(__name__) @@ -134,11 +135,13 @@ def load_lora( rank = int(os.environ.get("LOCAL_RANK", 0)) + is_torchao = cfg.peft and cfg.peft.backend == "torchao" if ( cfg.fsdp_config and cfg.adapter and cfg.fsdp_config.cpu_ram_efficient_loading and rank != 0 + and not is_torchao ): setup_quantized_meta_for_peft(model) @@ -146,6 +149,15 @@ def load_lora( if cfg.peft_autocast_adapter_dtype is not None: model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype + # Patch PEFT's torchao dispatch before any model creation/loading. + # Must happen before both get_peft_model and PeftModel.from_pretrained, + # as both trigger LoRA layer dispatch that would fail for INT4/NF4 weights. + # INT8 is natively supported by PEFT's TorchaoLoraLinear, so skip the patch. + if is_torchao and cfg.peft.weight_dtype != TorchAOQuantDType.int8: + from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch + + patch_peft_torchao_dispatch() + if cfg.lora_model_dir: LOG.debug("Loading pretrained PEFT - LoRA") if cfg.lora_on_cpu: @@ -172,6 +184,7 @@ def load_lora( and cfg.adapter and cfg.fsdp_config.cpu_ram_efficient_loading and rank != 0 + and not is_torchao ): setup_quantized_peft_meta_for_training(model) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 6c8885526..3ff2116ea 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -158,6 +158,15 @@ class ModelLoader: """Property that determines if FSDP with QLoRA is enabled.""" return self.is_fsdp_enabled and self.cfg.adapter == "qlora" + @property + def is_torchao_qlora(self): + """Property that determines if torchao backend is used for QLoRA.""" + return ( + self.cfg.adapter == "qlora" + and self.cfg.peft + and self.cfg.peft.backend == "torchao" + ) + @send_errors def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]: """Load and prepare the model with all configurations and patches. 
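The load_lora comment above pins down an ordering constraint: patch_peft_torchao_dispatch() must run before get_peft_model or PeftModel.from_pretrained triggers LoRA layer dispatch, otherwise INT4/NF4 weights fail PEFT's _check_dtype_supported(). A minimal sketch of that ordering outside the loader, assuming a CUDA device and compatible transformers/torchao versions; the model id, LoRA rank, and target modules are illustrative and not taken from this patch:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, TorchAoConfig

    from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch

    # 1. Bypass PEFT's TorchaoLoraLinear dispatch first; INT4/NF4 weights would
    #    otherwise be rejected by PEFT's _check_dtype_supported() during dispatch.
    patch_peft_torchao_dispatch()

    # 2. Load the base model with torchao int4 weight-only quantization
    #    (mirrors the TorchAoConfig branch added to ModelLoader below).
    base = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M",  # illustrative model id
        quantization_config=TorchAoConfig(quant_type="int4_weight_only", group_size=128),
        device_map="auto",
    )

    # 3. Only now create the PEFT model; dispatch falls back to plain Linear LoRA.
    peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))
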
@@ -491,8 +500,9 @@ class ModelLoader: # FSDP requires control over device placement, so don't set device_map when FSDP is enabled if self.is_fsdp_enabled: - # For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization - if self.is_qlora_and_fsdp_enabled: + # For QLoRA + FSDP with bnb, we still need to set device_map for proper initialization + # torchao tensors work natively with FSDP2, no device_map override needed + if self.is_qlora_and_fsdp_enabled and not self.is_torchao_qlora: self.model_kwargs["device_map"] = { "": int(os.environ.get("LOCAL_RANK", 0)) } @@ -561,6 +571,44 @@ class ModelLoader: self.model_kwargs["quantization_config"] = BitsAndBytesConfig( **self.model_config.quantization_config ) + elif ( + self.cfg.adapter == "qlora" + and self.cfg.peft + and self.cfg.peft.backend == "torchao" + and not self.cfg.merge_lora + ): + from transformers import TorchAoConfig + + from axolotl.utils.schemas.enums import TorchAOQuantDType + + weight_dtype = self.cfg.peft.weight_dtype + if weight_dtype == TorchAOQuantDType.int4: + group_size = self.cfg.peft.group_size or 128 + self.model_kwargs["quantization_config"] = TorchAoConfig( + quant_type="int4_weight_only", + group_size=group_size, + ) + elif weight_dtype == TorchAOQuantDType.int8: + group_size = self.cfg.peft.group_size or 128 + self.model_kwargs["quantization_config"] = TorchAoConfig( + quant_type="int8_weight_only", + group_size=group_size, + ) + elif weight_dtype == TorchAOQuantDType.nf4: + from torchao.dtypes._nf4tensor_api import NF4WeightOnlyConfig + + block_size = self.cfg.peft.group_size or 64 + self.model_kwargs["quantization_config"] = TorchAoConfig( + quant_type=NF4WeightOnlyConfig( + block_size=block_size, + scaler_block_size=256, + ), + ) + else: + raise ValueError( + f"Unsupported torchao weight_dtype for QLoRA: {weight_dtype}. " + "Supported: int4, int8, nf4" + ) elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit: bnb_config = { "load_in_4bit": True, @@ -860,6 +908,10 @@ class ModelLoader: # Make sure everything is in the same dtype skip_prepare_model_for_kbit_training = True + # torchao quantized models don't use Params4bit and don't need kbit preparation + if self.is_torchao_qlora: + skip_prepare_model_for_kbit_training = True + if ( not skip_prepare_model_for_kbit_training and self.cfg.adapter in ["lora", "qlora"] diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 3cf8bbd20..dd84c883f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -348,10 +348,12 @@ class PatchManager: def _apply_fsdp2_bnb_patches(self): """Apply FSDP2 BNB patches.""" + is_torchao = self.cfg.peft and self.cfg.peft.backend == "torchao" if ( self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2" and self.cfg.adapter == "qlora" + and not is_torchao ): from axolotl.monkeypatch.fsdp2_qlora import ( apply_init_sharded_param_patch, diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index d1011f5eb..0e1adba64 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -78,3 +78,30 @@ def patch_peft_prep_code(): axolotl.loaders.model.prepare_model_for_kbit_training = ( fixed_prepare_model_for_kbit_training ) + + +def patch_peft_torchao_dispatch(): + """Skip PEFT's TorchaoLoraLinear for non-INT8 torchao weights. + + PEFT's dispatch_torchao() matches AffineQuantizedTensor but then errors in + _check_dtype_supported() because it only allows INT8. 
Our LoRA kernels handle + dequantization explicitly, so we bypass PEFT's torchao dispatch entirely and + let it fall back to standard Linear LoRA layers. + """ + try: + from peft.tuners.lora import torchao as peft_torchao + except ImportError: + LOG.warning("Could not import peft.tuners.lora.torchao for patching") + return + + if getattr(peft_torchao, "_axolotl_patched", False): + return + + def patched_dispatch(target, adapter_name, lora_config, **kwargs): + # Return None so PEFT falls back to standard Linear LoRA layers. + # Our LoRA kernels handle torchao dequantization explicitly. + return None + + peft_torchao.dispatch_torchao = patched_dispatch + peft_torchao._axolotl_patched = True + LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index b67888e0f..236b94f68 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -8,6 +8,7 @@ import torch class TorchAOQuantDType(Enum): int4 = torch.int4 int8 = torch.int8 + nf4 = "nf4" float8_e4m3fn = torch.float8_e4m3fn nvfp4 = "nvfp4" @@ -16,6 +17,8 @@ class TorchAOQuantDType(Enum): return TorchAOQuantDType.int4 if str == "int8": return TorchAOQuantDType.int8 + if str == "nf4": + return TorchAOQuantDType.nf4 if str in ["float8_e4m3fn", "fp8", "float8"]: return TorchAOQuantDType.float8_e4m3fn if str == "nvfp4": diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index a9ce1fbd6..6c86174b4 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -1,9 +1,12 @@ """Pydantic models for PEFT-related configuration""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator, model_validator +from axolotl.utils.schemas.enums import TorchAOQuantDType +from axolotl.utils.schemas.quantization import validate_ao_dtype + class LoftQConfig(BaseModel): """LoftQ configuration subset""" @@ -15,7 +18,7 @@ class LoftQConfig(BaseModel): class PeftConfig(BaseModel): - """peftq configuration subset""" + """PEFT configuration subset""" loftq_config: LoftQConfig | None = Field( default=None, @@ -23,6 +26,29 @@ class PeftConfig(BaseModel): "description": "Configuration options for loftq initialization for LoRA" }, ) + backend: Literal["bnb", "torchao"] | None = Field( + default=None, + json_schema_extra={ + "description": "Quantization backend for QLoRA. 'bnb' for bitsandbytes (default), 'torchao' for torchao." + }, + ) + weight_dtype: TorchAOQuantDType | None = Field( + default=None, + json_schema_extra={ + "description": "Weight quantization dtype (int4, int8, or nf4). Also used with bnb backend to auto-configure quantization." + }, + ) + group_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Group size for quantization. Defaults to 128 for int4, 64 for nf4." + }, + ) + + @field_validator("weight_dtype", mode="before") + @classmethod + def validate_weight_dtype(cls, v): + return validate_ao_dtype(v) class LoraConfig(BaseModel): @@ -156,6 +182,56 @@ class LoraConfig(BaseModel): merge_lora: bool | None = None + @model_validator(mode="before") + @classmethod + def auto_detect_qlora(cls, data): + """Auto-set adapter type and quantization flags from peft config. + + When peft.backend and peft.weight_dtype are set, this infers the correct + adapter type and internal flags (load_in_4bit, load_in_8bit) so users + don't need to set them manually. 
+ """ + peft = data.get("peft") + if not isinstance(peft, dict): + return data + + backend = peft.get("backend") + weight_dtype = peft.get("weight_dtype") + + # Validate: weight_dtype requires backend + if weight_dtype and not backend: + raise ValueError( + "peft.backend is required when peft.weight_dtype is set. " + "Use 'torchao' or 'bnb'." + ) + + if not weight_dtype: + return data + + adapter = data.get("adapter") + + if backend == "torchao": + # torchao: any quantized weight_dtype means qlora + if adapter == "lora": + data["adapter"] = "qlora" + + elif backend == "bnb": + if weight_dtype == "nf4": + # bnb nf4 = qlora with load_in_4bit + if adapter == "lora": + data["adapter"] = "qlora" + data.setdefault("load_in_4bit", True) + elif weight_dtype == "int8": + # bnb int8 = lora with load_in_8bit + data.setdefault("load_in_8bit", True) + else: + raise ValueError( + f"peft.weight_dtype '{weight_dtype}' is not supported with bnb backend. " + "Supported: nf4, int8." + ) + + return data + @model_validator(mode="before") @classmethod def validate_adapter(cls, data): @@ -173,6 +249,8 @@ class LoraConfig(BaseModel): @model_validator(mode="after") def validate_qlora(self): if self.adapter == "qlora": + is_torchao = self.peft and self.peft.backend == "torchao" + if self.merge_lora: # can't merge qlora if loaded in 8bit or 4bit if self.load_in_8bit: @@ -184,7 +262,20 @@ class LoraConfig(BaseModel): if self.load_in_4bit: raise ValueError("Can't merge qlora if loaded in 4bit") + elif is_torchao: + # torchao backend: validate torchao-specific requirements + if self.load_in_4bit or self.load_in_8bit: + raise ValueError( + "load_in_4bit/load_in_8bit are for bitsandbytes. " + "With peft.backend: torchao, quantization is handled by torchao." + ) + if not self.peft.weight_dtype: + raise ValueError( + "peft.weight_dtype is required when peft.backend is 'torchao'" + ) + else: + # Default bnb path if self.load_in_8bit: raise ValueError("Can't load qlora in 8bit") diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py index a7c130574..2a05eea8d 100644 --- a/src/axolotl/utils/schemas/quantization.py +++ b/src/axolotl/utils/schemas/quantization.py @@ -16,6 +16,8 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None: return TorchAOQuantDType.int4 if v == "int8": return TorchAOQuantDType.int8 + if v == "nf4": + return TorchAOQuantDType.nf4 if v in ["float8_e4m3fn", "fp8", "float8"]: return TorchAOQuantDType.float8_e4m3fn if v == "nvfp4":
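
For reference, a short sketch of what the new auto_detect_qlora validator infers from a raw config mapping. The dict literals are illustrative inputs, not taken from the patch or a real run; the expected outcomes restate the branches added above:

    # torchao backend: any quantized weight_dtype promotes adapter "lora" to "qlora";
    # load_in_4bit / load_in_8bit must stay unset (validate_qlora rejects them here).
    torchao_cfg = {
        "adapter": "lora",
        "peft": {"backend": "torchao", "weight_dtype": "int4", "group_size": 128},
    }
    # after validation: adapter == "qlora"; quantization is built via TorchAoConfig

    # bnb backend: nf4 implies qlora + load_in_4bit, int8 implies load_in_8bit.
    bnb_nf4_cfg = {
        "adapter": "lora",
        "peft": {"backend": "bnb", "weight_dtype": "nf4"},
    }
    # after validation: adapter == "qlora", load_in_4bit == True

    # weight_dtype without a backend is rejected outright.
    missing_backend_cfg = {"adapter": "lora", "peft": {"weight_dtype": "int4"}}
    # raises ValueError: "peft.backend is required when peft.weight_dtype is set. ..."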