feat: add torchao's int4, nf4, int8
This commit is contained in:
@@ -23,6 +23,7 @@ from axolotl.loaders.utils import get_linear_embedding_layers
|
|||||||
from axolotl.telemetry.errors import send_errors
|
from axolotl.telemetry.errors import send_errors
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -134,11 +135,13 @@ def load_lora(
|
|||||||
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
|
||||||
|
is_torchao = cfg.peft and cfg.peft.backend == "torchao"
|
||||||
if (
|
if (
|
||||||
cfg.fsdp_config
|
cfg.fsdp_config
|
||||||
and cfg.adapter
|
and cfg.adapter
|
||||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
and rank != 0
|
and rank != 0
|
||||||
|
and not is_torchao
|
||||||
):
|
):
|
||||||
setup_quantized_meta_for_peft(model)
|
setup_quantized_meta_for_peft(model)
|
||||||
|
|
||||||
@@ -146,6 +149,15 @@ def load_lora(
|
|||||||
if cfg.peft_autocast_adapter_dtype is not None:
|
if cfg.peft_autocast_adapter_dtype is not None:
|
||||||
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
|
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
|
||||||
|
|
||||||
|
# Patch PEFT's torchao dispatch before any model creation/loading.
|
||||||
|
# Must happen before both get_peft_model and PeftModel.from_pretrained,
|
||||||
|
# as both trigger LoRA layer dispatch that would fail for INT4/NF4 weights.
|
||||||
|
# INT8 is natively supported by PEFT's TorchaoLoraLinear, so skip the patch.
|
||||||
|
if is_torchao and cfg.peft.weight_dtype != TorchAOQuantDType.int8:
|
||||||
|
from axolotl.monkeypatch.peft.utils import patch_peft_torchao_dispatch
|
||||||
|
|
||||||
|
patch_peft_torchao_dispatch()
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||||
if cfg.lora_on_cpu:
|
if cfg.lora_on_cpu:
|
||||||
@@ -172,6 +184,7 @@ def load_lora(
|
|||||||
and cfg.adapter
|
and cfg.adapter
|
||||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
and rank != 0
|
and rank != 0
|
||||||
|
and not is_torchao
|
||||||
):
|
):
|
||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,15 @@ class ModelLoader:
|
|||||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||||
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_torchao_qlora(self):
|
||||||
|
"""Property that determines if torchao backend is used for QLoRA."""
|
||||||
|
return (
|
||||||
|
self.cfg.adapter == "qlora"
|
||||||
|
and self.cfg.peft
|
||||||
|
and self.cfg.peft.backend == "torchao"
|
||||||
|
)
|
||||||
|
|
||||||
@send_errors
|
@send_errors
|
||||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||||
"""Load and prepare the model with all configurations and patches.
|
"""Load and prepare the model with all configurations and patches.
|
||||||
@@ -491,8 +500,9 @@ class ModelLoader:
|
|||||||
|
|
||||||
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
|
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
|
# For QLoRA + FSDP with bnb, we still need to set device_map for proper initialization
|
||||||
if self.is_qlora_and_fsdp_enabled:
|
# torchao tensors work natively with FSDP2, no device_map override needed
|
||||||
|
if self.is_qlora_and_fsdp_enabled and not self.is_torchao_qlora:
|
||||||
self.model_kwargs["device_map"] = {
|
self.model_kwargs["device_map"] = {
|
||||||
"": int(os.environ.get("LOCAL_RANK", 0))
|
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||||
}
|
}
|
||||||
@@ -561,6 +571,44 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
self.cfg.adapter == "qlora"
|
||||||
|
and self.cfg.peft
|
||||||
|
and self.cfg.peft.backend == "torchao"
|
||||||
|
and not self.cfg.merge_lora
|
||||||
|
):
|
||||||
|
from transformers import TorchAoConfig
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
|
weight_dtype = self.cfg.peft.weight_dtype
|
||||||
|
if weight_dtype == TorchAOQuantDType.int4:
|
||||||
|
group_size = self.cfg.peft.group_size or 128
|
||||||
|
self.model_kwargs["quantization_config"] = TorchAoConfig(
|
||||||
|
quant_type="int4_weight_only",
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
elif weight_dtype == TorchAOQuantDType.int8:
|
||||||
|
group_size = self.cfg.peft.group_size or 128
|
||||||
|
self.model_kwargs["quantization_config"] = TorchAoConfig(
|
||||||
|
quant_type="int8_weight_only",
|
||||||
|
group_size=group_size,
|
||||||
|
)
|
||||||
|
elif weight_dtype == TorchAOQuantDType.nf4:
|
||||||
|
from torchao.dtypes._nf4tensor_api import NF4WeightOnlyConfig
|
||||||
|
|
||||||
|
block_size = self.cfg.peft.group_size or 64
|
||||||
|
self.model_kwargs["quantization_config"] = TorchAoConfig(
|
||||||
|
quant_type=NF4WeightOnlyConfig(
|
||||||
|
block_size=block_size,
|
||||||
|
scaler_block_size=256,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported torchao weight_dtype for QLoRA: {weight_dtype}. "
|
||||||
|
"Supported: int4, int8, nf4"
|
||||||
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
|
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
@@ -860,6 +908,10 @@ class ModelLoader:
|
|||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
# torchao quantized models don't use Params4bit and don't need kbit preparation
|
||||||
|
if self.is_torchao_qlora:
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
|
|||||||
@@ -348,10 +348,12 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_fsdp2_bnb_patches(self):
|
def _apply_fsdp2_bnb_patches(self):
|
||||||
"""Apply FSDP2 BNB patches."""
|
"""Apply FSDP2 BNB patches."""
|
||||||
|
is_torchao = self.cfg.peft and self.cfg.peft.backend == "torchao"
|
||||||
if (
|
if (
|
||||||
self.cfg.fsdp_config
|
self.cfg.fsdp_config
|
||||||
and str(self.cfg.fsdp_version) == "2"
|
and str(self.cfg.fsdp_version) == "2"
|
||||||
and self.cfg.adapter == "qlora"
|
and self.cfg.adapter == "qlora"
|
||||||
|
and not is_torchao
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
|
|||||||
@@ -78,3 +78,30 @@ def patch_peft_prep_code():
|
|||||||
axolotl.loaders.model.prepare_model_for_kbit_training = (
|
axolotl.loaders.model.prepare_model_for_kbit_training = (
|
||||||
fixed_prepare_model_for_kbit_training
|
fixed_prepare_model_for_kbit_training
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_torchao_dispatch():
|
||||||
|
"""Skip PEFT's TorchaoLoraLinear for non-INT8 torchao weights.
|
||||||
|
|
||||||
|
PEFT's dispatch_torchao() matches AffineQuantizedTensor but then errors in
|
||||||
|
_check_dtype_supported() because it only allows INT8. Our LoRA kernels handle
|
||||||
|
dequantization explicitly, so we bypass PEFT's torchao dispatch entirely and
|
||||||
|
let it fall back to standard Linear LoRA layers.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from peft.tuners.lora import torchao as peft_torchao
|
||||||
|
except ImportError:
|
||||||
|
LOG.warning("Could not import peft.tuners.lora.torchao for patching")
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(peft_torchao, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
def patched_dispatch(target, adapter_name, lora_config, **kwargs):
|
||||||
|
# Return None so PEFT falls back to standard Linear LoRA layers.
|
||||||
|
# Our LoRA kernels handle torchao dequantization explicitly.
|
||||||
|
return None
|
||||||
|
|
||||||
|
peft_torchao.dispatch_torchao = patched_dispatch
|
||||||
|
peft_torchao._axolotl_patched = True
|
||||||
|
LOG.info("Patched PEFT dispatch_torchao to skip TorchaoLoraLinear")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch
|
|||||||
class TorchAOQuantDType(Enum):
|
class TorchAOQuantDType(Enum):
|
||||||
int4 = torch.int4
|
int4 = torch.int4
|
||||||
int8 = torch.int8
|
int8 = torch.int8
|
||||||
|
nf4 = "nf4"
|
||||||
float8_e4m3fn = torch.float8_e4m3fn
|
float8_e4m3fn = torch.float8_e4m3fn
|
||||||
nvfp4 = "nvfp4"
|
nvfp4 = "nvfp4"
|
||||||
|
|
||||||
@@ -16,6 +17,8 @@ class TorchAOQuantDType(Enum):
|
|||||||
return TorchAOQuantDType.int4
|
return TorchAOQuantDType.int4
|
||||||
if str == "int8":
|
if str == "int8":
|
||||||
return TorchAOQuantDType.int8
|
return TorchAOQuantDType.int8
|
||||||
|
if str == "nf4":
|
||||||
|
return TorchAOQuantDType.nf4
|
||||||
if str in ["float8_e4m3fn", "fp8", "float8"]:
|
if str in ["float8_e4m3fn", "fp8", "float8"]:
|
||||||
return TorchAOQuantDType.float8_e4m3fn
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
if str == "nvfp4":
|
if str == "nvfp4":
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
from axolotl.utils.schemas.quantization import validate_ao_dtype
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
@@ -15,7 +18,7 @@ class LoftQConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PeftConfig(BaseModel):
|
class PeftConfig(BaseModel):
|
||||||
"""peftq configuration subset"""
|
"""PEFT configuration subset"""
|
||||||
|
|
||||||
loftq_config: LoftQConfig | None = Field(
|
loftq_config: LoftQConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -23,6 +26,29 @@ class PeftConfig(BaseModel):
|
|||||||
"description": "Configuration options for loftq initialization for LoRA"
|
"description": "Configuration options for loftq initialization for LoRA"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
backend: Literal["bnb", "torchao"] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Quantization backend for QLoRA. 'bnb' for bitsandbytes (default), 'torchao' for torchao."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
weight_dtype: TorchAOQuantDType | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Weight quantization dtype (int4, int8, or nf4). Also used with bnb backend to auto-configure quantization."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
group_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Group size for quantization. Defaults to 128 for int4, 64 for nf4."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("weight_dtype", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_weight_dtype(cls, v):
|
||||||
|
return validate_ao_dtype(v)
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
@@ -156,6 +182,56 @@ class LoraConfig(BaseModel):
|
|||||||
|
|
||||||
merge_lora: bool | None = None
|
merge_lora: bool | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def auto_detect_qlora(cls, data):
|
||||||
|
"""Auto-set adapter type and quantization flags from peft config.
|
||||||
|
|
||||||
|
When peft.backend and peft.weight_dtype are set, this infers the correct
|
||||||
|
adapter type and internal flags (load_in_4bit, load_in_8bit) so users
|
||||||
|
don't need to set them manually.
|
||||||
|
"""
|
||||||
|
peft = data.get("peft")
|
||||||
|
if not isinstance(peft, dict):
|
||||||
|
return data
|
||||||
|
|
||||||
|
backend = peft.get("backend")
|
||||||
|
weight_dtype = peft.get("weight_dtype")
|
||||||
|
|
||||||
|
# Validate: weight_dtype requires backend
|
||||||
|
if weight_dtype and not backend:
|
||||||
|
raise ValueError(
|
||||||
|
"peft.backend is required when peft.weight_dtype is set. "
|
||||||
|
"Use 'torchao' or 'bnb'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not weight_dtype:
|
||||||
|
return data
|
||||||
|
|
||||||
|
adapter = data.get("adapter")
|
||||||
|
|
||||||
|
if backend == "torchao":
|
||||||
|
# torchao: any quantized weight_dtype means qlora
|
||||||
|
if adapter == "lora":
|
||||||
|
data["adapter"] = "qlora"
|
||||||
|
|
||||||
|
elif backend == "bnb":
|
||||||
|
if weight_dtype == "nf4":
|
||||||
|
# bnb nf4 = qlora with load_in_4bit
|
||||||
|
if adapter == "lora":
|
||||||
|
data["adapter"] = "qlora"
|
||||||
|
data.setdefault("load_in_4bit", True)
|
||||||
|
elif weight_dtype == "int8":
|
||||||
|
# bnb int8 = lora with load_in_8bit
|
||||||
|
data.setdefault("load_in_8bit", True)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"peft.weight_dtype '{weight_dtype}' is not supported with bnb backend. "
|
||||||
|
"Supported: nf4, int8."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_adapter(cls, data):
|
def validate_adapter(cls, data):
|
||||||
@@ -173,6 +249,8 @@ class LoraConfig(BaseModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_qlora(self):
|
def validate_qlora(self):
|
||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
|
is_torchao = self.peft and self.peft.backend == "torchao"
|
||||||
|
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
if self.load_in_8bit:
|
if self.load_in_8bit:
|
||||||
@@ -184,7 +262,20 @@ class LoraConfig(BaseModel):
|
|||||||
if self.load_in_4bit:
|
if self.load_in_4bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
|
elif is_torchao:
|
||||||
|
# torchao backend: validate torchao-specific requirements
|
||||||
|
if self.load_in_4bit or self.load_in_8bit:
|
||||||
|
raise ValueError(
|
||||||
|
"load_in_4bit/load_in_8bit are for bitsandbytes. "
|
||||||
|
"With peft.backend: torchao, quantization is handled by torchao."
|
||||||
|
)
|
||||||
|
if not self.peft.weight_dtype:
|
||||||
|
raise ValueError(
|
||||||
|
"peft.weight_dtype is required when peft.backend is 'torchao'"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Default bnb path
|
||||||
if self.load_in_8bit:
|
if self.load_in_8bit:
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ def validate_ao_dtype(v: Any) -> TorchAOQuantDType | None:
|
|||||||
return TorchAOQuantDType.int4
|
return TorchAOQuantDType.int4
|
||||||
if v == "int8":
|
if v == "int8":
|
||||||
return TorchAOQuantDType.int8
|
return TorchAOQuantDType.int8
|
||||||
|
if v == "nf4":
|
||||||
|
return TorchAOQuantDType.nf4
|
||||||
if v in ["float8_e4m3fn", "fp8", "float8"]:
|
if v in ["float8_e4m3fn", "fp8", "float8"]:
|
||||||
return TorchAOQuantDType.float8_e4m3fn
|
return TorchAOQuantDType.float8_e4m3fn
|
||||||
if v == "nvfp4":
|
if v == "nvfp4":
|
||||||
|
|||||||
Reference in New Issue
Block a user