From 140083a828c8427109eb89f366b801b2307a41a3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 3 May 2025 02:32:47 -0400
Subject: [PATCH] patch peft to not upcast everything

---
 src/axolotl/monkeypatch/peft/__init__.py |  0
 src/axolotl/monkeypatch/peft/utils.py    | 79 ++++++++++++++++++++++++
 src/axolotl/utils/models.py              |  7 ++-
 3 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 src/axolotl/monkeypatch/peft/__init__.py
 create mode 100644 src/axolotl/monkeypatch/peft/utils.py

diff --git a/src/axolotl/monkeypatch/peft/__init__.py b/src/axolotl/monkeypatch/peft/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py
new file mode 100644
index 000000000..fed88a0ed
--- /dev/null
+++ b/src/axolotl/monkeypatch/peft/utils.py
@@ -0,0 +1,79 @@
+"""
+Patch prepare_model_for_kbit_training to not upcast everything
+"""
+
+import inspect
+import logging
+
+import peft
+
+from axolotl.monkeypatch.utils import detab_code
+
+LOG = logging.getLogger(__name__)
+
+ORIGINAL_PREPARE_CODE = """
+    for param in model.parameters():
+        if (
+            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+        ) and param.__class__.__name__ != "Params4bit":
+            param.data = param.data.to(torch.float32)
+"""
+
+PATCHED_PREPARE_CODE = """
+    for name, param in model.named_parameters():
+        if (
+            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+        ) and param.__class__.__name__ != "Params4bit" and "norm" in name:
+            param.data = param.data.to(torch.float32)
+"""
+
+
+def get_peft_prep_code() -> str:
+    # source of peft's prepare_model_for_kbit_training, as text
+    return inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
+
+
+def check_peft_prep_code_is_patchable() -> bool:
+    prep_code = get_peft_prep_code()
+    prep_code, _ = detab_code(prep_code)
+    return ORIGINAL_PREPARE_CODE in prep_code
+
+
+def patch_peft_prep_code():
+    """
+    monkeypatch prepare_model_for_kbit_training so it only upcasts norm layers
+    """
+
+    try:
+        prep_code = get_peft_prep_code()
+    except OSError:
+        return
+    peft.utils.other._original_prepare_model_for_kbit_training = (  # pylint: disable=protected-access
+        prep_code
+    )
+    prep_code, _ = detab_code(prep_code)
+    if ORIGINAL_PREPARE_CODE not in prep_code:
+        return
+
+    prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
+    prep_code = prep_code.replace(
+        "def prepare_model_for_kbit_training(",
+        "def fixed_prepare_model_for_kbit_training(",
+        1,
+    )
+
+    # collect the peft.utils.other names referenced by the patched source so exec() can resolve them
+    items_to_import = []
+    for item in dir(peft.utils.other):
+        if item in prep_code:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from peft.utils.other import (" + ", ".join(items_to_import) + ")",
+        globals(),
+    )
+    exec(prep_code, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching prepare_model_for_kbit_training to only upcast norm layers")
+    peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=undefined-variable  # noqa: F821
+    import axolotl.utils.models  # local import; binds the axolotl name used on the next line
+    axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=undefined-variable  # noqa: F821
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8ba26543c..eaaa2a450 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -561,6 +561,11 @@ class ModelLoader:
 
         patch_accelerate_fsdp_utils()
 
+        if self.cfg.adapter:
+            from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
+
+            patch_peft_prep_code()
+
         if self.cfg.flex_attention:
             from axolotl.monkeypatch.attention.flex_attn import (
                 patch_flex_make_mask,
@@ -1180,7 +1185,7 @@ class ModelLoader:
             ],
         )
 
-    def prepare_model(self, qlora_fsdp) -> None:
+    def prepare_model(self, qlora_fsdp: bool) -> None:
         skip_prepare_model_for_kbit_training = False
         if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
             # Qwen doesn't play nicely with LoRA if this is enabled
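
To illustrate the behavior change: stock peft's prepare_model_for_kbit_training upcasts every fp16/bf16 parameter to fp32, while the patched loop upcasts only parameters whose name contains "norm" (LayerNorm/RMSNorm weights, which are the ones that benefit from fp32 precision). Below is a minimal standalone sketch of the patched loop; upcast_norms_only and the toy module layout are illustrative, not part of the patch:

import torch
from torch import nn


def upcast_norms_only(model: nn.Module) -> None:
    # mirrors PATCHED_PREPARE_CODE: upcast half-precision params to fp32
    # only when the parameter name contains "norm"; bitsandbytes Params4bit
    # quantized weights are always skipped
    for name, param in model.named_parameters():
        if (
            param.dtype in (torch.float16, torch.bfloat16)
            and param.__class__.__name__ != "Params4bit"
            and "norm" in name
        ):
            param.data = param.data.to(torch.float32)


# toy module; "input_norm" is named so the substring check matches
model = nn.ModuleDict(
    {"proj": nn.Linear(8, 8), "input_norm": nn.LayerNorm(8)}
).to(torch.bfloat16)
upcast_norms_only(model)
for name, param in model.named_parameters():
    print(name, param.dtype)
# proj.weight torch.bfloat16
# proj.bias torch.bfloat16
# input_norm.weight torch.float32
# input_norm.bias torch.float32

Relative to peft's blanket upcast, this keeps the bulk of the base model's weights in half precision, cutting the memory overhead of kbit LoRA preparation while preserving fp32 norm weights for numerical stability.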