patch peft to not upcast everything
This commit is contained in:
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
77
src/axolotl/monkeypatch/peft/utils.py
Normal file
77
src/axolotl/monkeypatch/peft/utils.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""
|
||||||
|
Patch prepare_model_for_kbit_training to not upcast everything
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import peft
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Snippet looked up inside the source of peft's prepare_model_for_kbit_training:
# the loop that casts every float16/bfloat16 parameter, except Params4bit ones,
# to float32. It is used with `in` and `str.replace`, i.e. as an exact substring,
# so every character of this literal (whitespace included) is significant.
# NOTE(review): the `|` lines inside the literal look like diff-view residue —
# confirm against the repository copy of this file.
ORIGINAL_PREPARE_CODE = """
|
||||||
|
for param in model.parameters():
|
||||||
|
if (
|
||||||
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
|
) and param.__class__.__name__ != "Params4bit":
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Replacement text substituted for ORIGINAL_PREPARE_CODE: the same float32 cast,
# but iterating named_parameters() and additionally requiring "norm" to occur in
# the parameter name, so only those parameters are upcast.
# NOTE(review): the `|` lines inside the literal look like diff-view residue —
# confirm against the repository copy of this file.
PATCHED_PREPARE_CODE = """
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if (
|
||||||
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||||
|
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_peft_prep_code() -> str:
    """
    Return the source text of peft's prepare_model_for_kbit_training.

    The text is obtained with inspect.getsource, which raises OSError when the
    source cannot be retrieved.
    """
    target = peft.utils.other.prepare_model_for_kbit_training
    return inspect.getsource(target)
|
||||||
|
|
||||||
|
|
||||||
|
def check_peft_prep_code_is_patchable() -> bool:
    """
    Report whether the installed peft source still contains the snippet to patch.

    Returns True when ORIGINAL_PREPARE_CODE occurs in the detabbed source of
    prepare_model_for_kbit_training, False otherwise.
    """
    detabbed_source, _ = detab_code(get_peft_prep_code())
    if ORIGINAL_PREPARE_CODE in detabbed_source:
        return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_prep_code():
    """
    Monkeypatch peft's prepare_model_for_kbit_training to not upcast everything.

    The source of the peft function is fetched, ORIGINAL_PREPARE_CODE is
    swapped for PATCHED_PREPARE_CODE, and the result is compiled into this
    module's namespace as ``fixed_prepare_model_for_kbit_training``. That
    function is then bound over ``prepare_model_for_kbit_training`` on both
    ``peft.utils.other`` and ``axolotl.utils.models``.

    Returns None in every case. Nothing is replaced when the source cannot be
    read (OSError) or when it does not contain ORIGINAL_PREPARE_CODE.
    """
    try:
        prep_code = get_peft_prep_code()
    except OSError:
        # source text is not available; leave peft untouched
        return

    # Keeps the unmodified source text around. The attribute name is inherited
    # from another patch and is left unchanged for backward compatibility.
    peft.utils.other._original_create_accelerator_and_postprocess = (  # pylint: disable=protected-access
        prep_code
    )
    prep_code, _ = detab_code(prep_code)
    if ORIGINAL_PREPARE_CODE not in prep_code:
        return

    prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
    # rename only the definition itself (count=1) so the exec below does not
    # shadow anything called prepare_model_for_kbit_training
    prep_code = prep_code.replace(
        "def prepare_model_for_kbit_training(",
        "def fixed_prepare_model_for_kbit_training(",
        1,
    )

    # Names from peft.utils.other that the function text mentions. Dunder
    # attributes are skipped: "__name__" occurs in the text (via
    # ``param.__class__.__name__``) and importing it into globals() would
    # overwrite this module's own __name__.
    items_to_import = [
        item
        for item in dir(peft.utils.other)
        if not item.startswith("__") and item in prep_code
    ]

    if items_to_import:  # "from x import ()" is a syntax error
        exec(  # pylint: disable=exec-used  # nosec B102
            "from peft.utils.other import (" + ", ".join(items_to_import) + ")",
            globals(),
        )
    exec(prep_code, globals())  # pylint: disable=exec-used  # nosec B102
    patched = globals()["fixed_prepare_model_for_kbit_training"]

    # ``axolotl`` is not otherwise bound in this module; import it here so the
    # assignment below does not raise NameError.
    # NOTE(review): done inside the function on the assumption that a
    # module-level import could be circular — confirm.
    import axolotl.utils.models  # pylint: disable=import-outside-toplevel

    LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
    peft.utils.other.prepare_model_for_kbit_training = patched
    axolotl.utils.models.prepare_model_for_kbit_training = patched
|
||||||
@@ -561,6 +561,11 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_accelerate_fsdp_utils()
|
patch_accelerate_fsdp_utils()
|
||||||
|
|
||||||
|
if self.cfg.adapter:
|
||||||
|
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
|
||||||
|
|
||||||
|
patch_peft_prep_code()
|
||||||
|
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.flex_attention:
|
||||||
from axolotl.monkeypatch.attention.flex_attn import (
|
from axolotl.monkeypatch.attention.flex_attn import (
|
||||||
patch_flex_make_mask,
|
patch_flex_make_mask,
|
||||||
@@ -1180,7 +1185,7 @@ class ModelLoader:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_model(self, qlora_fsdp) -> None:
|
def prepare_model(self, qlora_fsdp: bool) -> None:
|
||||||
skip_prepare_model_for_kbit_training = False
|
skip_prepare_model_for_kbit_training = False
|
||||||
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||||
|
|||||||
Reference in New Issue
Block a user