From 2cf254b4afc2c3eea632a09448bd6b999ca50e93 Mon Sep 17 00:00:00 2001
From: xzuyn <16216325+xzuyn@users.noreply.github.com>
Date: Wed, 17 Dec 2025 10:09:39 -0500
Subject: [PATCH] Add `peft_autocast_adapter_dtype` config option (#3311)
 [skip ci]

* Add `peft_autocast_adapter_dtype` field to schema
* Add `autocast_adapter_dtype` to `model_kwargs`
* chore: docs

---------

Co-authored-by: NanoCode012
---
 src/axolotl/loaders/adapter.py    | 7 +++++--
 src/axolotl/utils/schemas/peft.py | 6 ++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index dca688bb2..3b64b23db 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -142,9 +142,12 @@ def load_lora(
     ):
         setup_quantized_meta_for_peft(model)

+    model_kwargs: Any = {}
+    if cfg.peft_autocast_adapter_dtype is not None:
+        model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - LoRA")
-        model_kwargs: Any = {}
         if cfg.lora_on_cpu:
             model_kwargs["max_memory"] = {"cpu": "256GiB"}
             model_kwargs["device_map"] = {"": "cpu"}
@@ -155,7 +158,7 @@ def load_lora(
             **model_kwargs,
         )
     else:
-        model = get_peft_model(model, lora_config)
+        model = get_peft_model(model, lora_config, **model_kwargs)

     if rank == 0:
         try:
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index fd16dec3f..a9ce1fbd6 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -109,6 +109,12 @@ class LoraConfig(BaseModel):
             )
         },
     )
+    peft_autocast_adapter_dtype: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT."
+        },
+    )
     qlora_sharded_model_loading: bool | None = Field(
         default=False,
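---

For context, `autocast_adapter_dtype` is the PEFT keyword argument this new
config option forwards; both `get_peft_model` and `PeftModel.from_pretrained`
accept it in recent PEFT versions, with a default of True. Below is a minimal
sketch of the behavior the option toggles (the model name and target module
are illustrative placeholders, not taken from this patch):

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Load a bf16 base model; "gpt2" and "c_attn" are placeholder choices.
    base = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
    lora_config = LoraConfig(r=8, target_modules=["c_attn"])

    # autocast_adapter_dtype=True (PEFT's default) upcasts bf16/fp16 adapter
    # weights to fp32 for numerical stability; False keeps them in bf16.
    model = get_peft_model(base, lora_config, autocast_adapter_dtype=False)

    lora_a = next(p for n, p in model.named_parameters() if "lora_A" in n)
    print(lora_a.dtype)  # torch.bfloat16; torch.float32 under the default

In an axolotl config this corresponds to setting
`peft_autocast_adapter_dtype: false`. Leaving the option unset preserves
PEFT's default, since the patch only forwards the kwarg when the value is
not None.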