diff --git a/docs/config.qmd b/docs/config.qmd
index fae64501a..857d0eb03 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -32,6 +32,8 @@ tokenizer_legacy:
 resize_token_embeddings_to_32x:
 # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
 shrink_embeddings:
+# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs.
+embeddings_skip_upcast:
 # Whether to load the model with randomly initialized weights. Useful for
 # pre-training a model from scratch or debugging purposes.
 random_init_weights:
diff --git a/src/axolotl/monkeypatch/peft/__init__.py b/src/axolotl/monkeypatch/peft/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py
new file mode 100644
index 000000000..b3703d398
--- /dev/null
+++ b/src/axolotl/monkeypatch/peft/utils.py
@@ -0,0 +1,78 @@
+"""
+Patch prepare_model_for_kbit_training so it does not upcast the embeddings to float32
+"""
+
+import inspect
+import logging
+
+import peft
+
+import axolotl
+from axolotl.monkeypatch.utils import detab_code
+
+LOG = logging.getLogger(__name__)
+
+ORIGINAL_PREPARE_CODE = """
+    for param in model.parameters():
+        if (
+            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+        ) and param.__class__.__name__ != "Params4bit":
+            param.data = param.data.to(torch.float32)
+"""
+
+PATCHED_PREPARE_CODE = """
+    for name, param in model.named_parameters():
+        if (
+            (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+        ) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
+            param.data = param.data.to(torch.float32)
+"""
+
+
+def get_peft_prep_code() -> str:
+    prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
+    return prepare
+
+
+def check_peft_prep_code_is_patchable() -> bool:
+    prep_code = get_peft_prep_code()
+    prep_code, _ = detab_code(prep_code)
+    return ORIGINAL_PREPARE_CODE in prep_code
+
+
+def patch_peft_prep_code():
+    """
+    monkeypatch prepare_model_for_kbit_training so it skips upcasting the embedding and lm_head weights
+    """
+
+    try:
+        prep_code = get_peft_prep_code()
+    except OSError:
+        return
+    peft.utils.other._original_prepare_model_for_kbit_training = (  # pylint: disable=protected-access
+        prep_code
+    )
+    prep_code, _ = detab_code(prep_code)
+    if ORIGINAL_PREPARE_CODE not in prep_code:
+        return
+
+    prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
+    prep_code = prep_code.replace(
+        "def prepare_model_for_kbit_training(",
+        "def fixed_prepare_model_for_kbit_training(",
+        1,
+    )
+
+    items_to_import = []
+    for item in dir(peft.utils.other):
+        if item in prep_code:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
+        globals(),
+    )
+    exec(prep_code, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
+    peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=undefined-variable  # noqa: F821
+    axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training  # pylint: disable=undefined-variable  # noqa: F821
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 92766d44c..6aa4dd162 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -566,6 +566,11 @@ class ModelLoader:
 
         patch_accelerate_fsdp_utils()
 
+        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
+            from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
+
+            patch_peft_prep_code()
+
         if self.cfg.flex_attention:
             from axolotl.monkeypatch.attention.flex_attn import (
                 patch_flex_make_mask,
@@ -1185,7 +1190,7 @@ class ModelLoader:
             ],
         )
 
-    def prepare_model(self, qlora_fsdp) -> None:
+    def prepare_model(self, qlora_fsdp: bool) -> None:
         skip_prepare_model_for_kbit_training = False
         if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
             # Qwen doesn't play nicely with LoRA if this is enabled
@@ -1315,7 +1320,10 @@ class ModelLoader:
         # make sure these are fp32 per Ramesh et al. (2021)
         embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
         if not self.cfg.fsdp:
-            # FSDP doesn't like mixed Float and BFloat16
+            # we don't run this during FSDP because this will leave mixed
+            # float and bfloat16 dtypes in the model, which FSDP doesn't like
+            if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
+                embedding_modules = []
             self.convert_embedding_modules_dtype(
                 embedding_modules,
                 dist_dtype=torch.float32,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 3527ec56e..9db374409 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -82,6 +82,7 @@ class AxolotlInputConfig(
     mean_resizing_embeddings: bool | None = False
     # optionally shrink the embeddings when the tokenizer vocab size is smaller
    shrink_embeddings: bool | None = None
+    embeddings_skip_upcast: bool | None = None
 
     rl: RLType | None = None
     trl: TRLConfig | None = Field(
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
new file mode 100644
index 000000000..d4f59a128
--- /dev/null
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -0,0 +1,63 @@
+"""
+Test case for handling embeddings when using PEFT
+"""
+
+import torch
+
+from axolotl.train import setup_model_and_tokenizer
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestLlamaPeftEmbeddings:
+    """
+    test class for handling embeddings when using PEFT
+    """
+
+    def test_peft_embeddings_upcast(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_target_linear": True,
+                "trust_remote_code": True,
+                "sequence_len": 512,
+                "val_set_size": 0.01,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": False,
+                "bf16": "auto",
+                "save_safetensors": True,
+                "embeddings_skip_upcast": True,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+
+        model, _, _, _ = setup_model_and_tokenizer(cfg)
+
+        # check that the embeddings were not upcast to float32; with
+        # embeddings_skip_upcast set, embed_tokens and lm_head stay in bfloat16
+        assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
+        assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16
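
Usage sketch: a minimal QLoRA config enabling the new flag, adapted from the test case above (the base model is simply the one the test happens to use):

    base_model: HuggingFaceTB/SmolLM2-135M
    load_in_4bit: true
    adapter: qlora
    lora_r: 8
    lora_alpha: 16
    lora_target_linear: true
    bf16: auto
    embeddings_skip_upcast: true

With the flag set, the patched prepare_model_for_kbit_training leaves embed_tokens and lm_head in bfloat16 rather than upcasting them to float32, saving VRAM at the cost of the numerical headroom the upcast normally provides (per the Ramesh et al. comment in models.py).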