Files
axolotl/src/axolotl/utils/schemas/peft.py
Dan Saunders 113e9cd193 Autodoc generation with quartodoc (#2419)
* quartodoc integration

* quartodoc progress

* deletions

* Update docs/.gitignore to exclude auto-generated API documentation files

* Fix

* more autodoc progress

* moving reference up near the top of the sidebar

* fix broken link

* update to reflect recent changes

* pydantic models refactor + add to autodoc + fixes

* fix

* shrinking header sizes

* fix accidental change

* include quartodoc build step

* update pre-commit version

* update pylint

* pre-commit

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-03-21 12:26:47 -04:00

133 lines
4.4 KiB
Python

"""Pydantic models for PEFT-related configuration"""
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel):
    """LoftQ configuration subset.

    Currently exposes a single option: the quantization bit width.
    """

    # Defaults to 4 bits; the description is attached as extra JSON-schema
    # metadata rather than through Field's own `description` argument.
    loftq_bits: int = Field(
        4,
        json_schema_extra={"description": "Quantization bits for LoftQ"},
    )
    # NOTE: an iteration-count option is sketched below but is not enabled.
    # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
    """PEFT configuration subset.

    Currently carries only the optional nested LoftQ settings.
    """

    # Stays None when no LoftQ settings are supplied.
    loftq_config: LoftQConfig | None = Field(default=None)
class LoraConfig(BaseModel):
    """Peft / LoRA configuration subset.

    Declares the adapter-related options (quantized-loading flags, LoRA
    hyper-parameters, PEFT extras, LoRA+ learning rates) together with the
    cross-field checks that keep those options mutually consistent.
    """

    # --- quantized loading / adapter selection ---
    load_in_8bit: bool | None = Field(default=False)
    load_in_4bit: bool | None = Field(default=False)

    adapter: str | None = None
    lora_model_dir: str | None = None

    # --- LoRA hyper-parameters ---
    lora_r: int | None = None
    lora_alpha: int | None = None
    lora_fan_in_fan_out: bool | None = None
    lora_target_modules: str | list[str] | None = None
    lora_target_linear: bool | None = None
    lora_modules_to_save: list[str] | None = None
    lora_dropout: float | None = 0.0

    # --- PEFT extras ---
    peft_layers_to_transform: list[int] | None = None
    peft_layers_pattern: list[str] | None = None
    peft: PeftConfig | None = None
    peft_use_dora: bool | None = None
    peft_use_rslora: bool | None = None
    peft_layer_replication: list[tuple[int, int]] | None = None
    peft_init_lora_weights: bool | str | None = None

    qlora_sharded_model_loading: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
        },
    )
    lora_on_cpu: bool | None = None
    gptq: bool | None = None
    bnb_config_kwargs: dict[str, Any] | None = None

    # --- LoRA+ ---
    loraplus_lr_ratio: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
        },
    )
    loraplus_lr_embedding: float | None = Field(
        default=1e-6,
        json_schema_extra={
            "description": "loraplus learning rate for lora embedding layers."
        },
    )

    merge_lora: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_adapter(cls, data):
        """Reject quantized loading when no adapter is set for training.

        Runs on the raw input before field validation. Note that it reads
        an ``inference`` key that is not a field declared on this class.

        Args:
            data: Raw input handed to the model by pydantic.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If ``load_in_8bit`` or ``load_in_4bit`` is truthy
                while both ``adapter`` and ``inference`` are falsy.
        """
        # A "before" validator can receive arbitrary input; only dicts
        # support the key lookups below, anything else is left to pydantic.
        if not isinstance(data, dict):
            return data

        wants_quantized = data.get("load_in_8bit") or data.get("load_in_4bit")
        if wants_quantized and not data.get("adapter") and not data.get("inference"):
            # The first literal ends with a space so the two sentences do not
            # run together once the adjacent literals are concatenated.
            raise ValueError(
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        return data

    @model_validator(mode="after")
    def validate_qlora(self):
        """Check that the qlora adapter is combined with compatible options.

        Only applies when ``adapter == "qlora"``. When merging, none of
        ``load_in_8bit``, ``gptq`` or ``load_in_4bit`` may be set; otherwise
        ``load_in_8bit`` and ``gptq`` must be unset and ``load_in_4bit``
        must be set.

        Returns:
            The validated model instance.

        Raises:
            ValueError: On the first incompatible option found.
        """
        if self.adapter != "qlora":
            return self

        if self.merge_lora:
            # can't merge qlora if loaded in 8bit or 4bit
            if self.load_in_8bit:
                raise ValueError("Can't merge qlora if loaded in 8bit")
            if self.gptq:
                raise ValueError("Can't merge qlora if gptq")
            if self.load_in_4bit:
                raise ValueError("Can't merge qlora if loaded in 4bit")
        else:
            if self.load_in_8bit:
                raise ValueError("Can't load qlora in 8bit")
            if self.gptq:
                raise ValueError("Can't load qlora if gptq")
            if not self.load_in_4bit:
                raise ValueError("Require cfg.load_in_4bit to be True for qlora")
        return self

    @field_validator("loraplus_lr_embedding")
    @classmethod
    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
        """Convert a non-empty string ``loraplus_lr_embedding`` to ``float``.

        Any other value is returned untouched.

        NOTE(review): ``field_validator`` defaults to "after" mode, so the
        value has presumably already been coerced to ``float`` by pydantic
        before this runs - confirm whether ``mode="before"`` was intended.
        """
        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
            loraplus_lr_embedding = float(loraplus_lr_embedding)
        return loraplus_lr_embedding

    @model_validator(mode="before")
    @classmethod
    def validate_lora_dropout(cls, data):
        """Fill in ``lora_dropout = 0.0`` when an adapter is set without one.

        Runs on the raw input before field validation and updates the input
        dict in place.

        Args:
            data: Raw input handed to the model by pydantic.

        Returns:
            ``data``, with ``lora_dropout`` set to ``0.0`` if it was missing
            or ``None`` while ``adapter`` is not ``None``.
        """
        # Same guard as validate_adapter: skip inputs without dict access.
        if not isinstance(data, dict):
            return data

        if data.get("adapter") is not None and data.get("lora_dropout") is None:
            data["lora_dropout"] = 0.0
        return data
class ReLoRAConfig(BaseModel):
    """ReLoRA configuration subset.

    Every option is optional and defaults to None when not supplied.
    """

    # Integer step-count options.
    relora_steps: int | None = Field(default=None)
    relora_warmup_steps: int | None = Field(default=None)
    relora_anneal_steps: int | None = Field(default=None)

    # Float ratio option.
    relora_prune_ratio: float | None = Field(default=None)

    # Boolean toggle.
    relora_cpu_offload: bool | None = Field(default=None)