diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py
index f9ea91727..b5d2f7cc9 100644
--- a/src/axolotl/utils/lora_embeddings.py
+++ b/src/axolotl/utils/lora_embeddings.py
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["lm_head", "embed_tokens"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0e7633a3b..41c43a029 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -588,13 +588,14 @@ def load_model(
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
@@ -619,15 +620,12 @@
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
 
diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index e8987ef45..19042639f 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )
 
diff --git a/tests/test_validation.py b/tests/test_validation.py
index c952b7fcf..79e7e73a6 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )
 
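
For reference, a minimal self-contained sketch (not part of the diff) of how the refactored helper drives the upcast loop in load_model. DummyModel and its layer names are hypothetical stand-ins for a loaded model; get_linear_embedding_layers is copied from the patched lora_embeddings.py above.

import torch
import torch.nn as nn


def get_linear_embedding_layers(model_type):
    """returns the linear embedding layers needed for loras, dependent on the model arch"""
    if model_type == "phi-msft":
        return ["embd.wte", "lm_head.linear"]
    if model_type == "gpt_neox":
        return ["embed_in", "embed_out"]
    return ["embed_tokens", "lm_head"]


class DummyModel(nn.Module):
    # hypothetical stand-in for a loaded HF model with llama-style layer names
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4).half()
        self.lm_head = nn.Linear(4, 8).half()


model = DummyModel()
embedding_modules = get_linear_embedding_layers("llama")

# same substring match the patch introduces in load_model: upcast any module
# whose qualified name contains one of the per-arch embedding layer names
for name, module in model.named_modules():
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch.float32)

assert model.embed_tokens.weight.dtype == torch.float32
assert model.lm_head.weight.dtype == torch.float32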