diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index dca688bb2..3b64b23db 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -142,9 +142,12 @@ def load_lora( ): setup_quantized_meta_for_peft(model) + model_kwargs: Any = {} + if cfg.peft_autocast_adapter_dtype is not None: + model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype + if cfg.lora_model_dir: LOG.debug("Loading pretrained PEFT - LoRA") - model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} model_kwargs["device_map"] = {"": "cpu"} @@ -155,7 +158,7 @@ def load_lora( **model_kwargs, ) else: - model = get_peft_model(model, lora_config) + model = get_peft_model(model, lora_config, **model_kwargs) if rank == 0: try: diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index fd16dec3f..a9ce1fbd6 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -109,6 +109,12 @@ class LoraConfig(BaseModel): ) }, ) + peft_autocast_adapter_dtype: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT." + }, + ) qlora_sharded_model_loading: bool | None = Field( default=False,