Add peft_autocast_adapter_dtype config option (#3311) [skip ci]
* Add `peft_autocast_adapter_dtype` field to schema
* Add `autocast_adapter_dtype` to `model_kwargs`
* chore: docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -142,9 +142,12 @@ def load_lora(
|
|||||||
):
|
):
|
||||||
setup_quantized_meta_for_peft(model)
|
setup_quantized_meta_for_peft(model)
|
||||||
|
|
||||||
|
model_kwargs: Any = {}
|
||||||
|
if cfg.peft_autocast_adapter_dtype is not None:
|
||||||
|
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||||
model_kwargs: Any = {}
|
|
||||||
if cfg.lora_on_cpu:
|
if cfg.lora_on_cpu:
|
||||||
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
||||||
model_kwargs["device_map"] = {"": "cpu"}
|
model_kwargs["device_map"] = {"": "cpu"}
|
||||||
@@ -155,7 +158,7 @@ def load_lora(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, lora_config, **model_kwargs)
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -109,6 +109,12 @@ class LoraConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
peft_autocast_adapter_dtype: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
qlora_sharded_model_loading: bool | None = Field(
|
qlora_sharded_model_loading: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user