add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)

Wing Lian
2024-01-10 22:32:43 -05:00
committed by GitHub
parent 23495a80af
commit 78c5b1979e
4 changed files with 10 additions and 9 deletions

View File

@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["lm_head", "embed_tokens"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]

View File

@@ -588,13 +588,14 @@ def load_model(
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
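
The switch to any(m in name for m in embedding_modules) is a substring match over fully qualified module names, so arch-specific entries like "embd.wte" also catch nested paths such as "transformer.embd.wte". A minimal, self-contained sketch of the upcast loop on a toy module tree (not axolotl's real model classes):

import torch
import torch.nn as nn

# toy stand-in for the phi-msft layout; real module paths may differ
model = nn.ModuleDict({"embd": nn.ModuleDict({"wte": nn.Embedding(10, 4)})}).half()
embedding_modules = ["embd.wte", "lm_head.linear"]

for name, module in model.named_modules():
    # substring match: "embd.wte" matches the nested module named "embd.wte"
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch.float32)

assert model["embd"]["wte"].weight.dtype == torch.float32
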
@@ -619,15 +620,12 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
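
Dropping the llama/mistral-derived check means any run with flash attention (or needs_fa2_dtype) casts norms and the per-arch embedding modules back to cfg.torch_dtype. A standalone sketch of that downcast on a toy module tree with assumed names, not axolotl's real model:

import torch
import torch.nn as nn

# toy stand-in: one norm plus a gpt_neox-style input embedding (names assumed)
model = nn.ModuleDict(
    {"final_layer_norm": nn.LayerNorm(4), "embed_in": nn.Embedding(10, 4)}
)
model.to(torch.float32)                   # state after the fp32 upcast above
embedding_modules = ["embed_in", "embed_out"]
torch_dtype = torch.bfloat16              # stand-in for cfg.torch_dtype

# flash-attn kernels expect half precision, so cast norms and embeddings back down
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if any(m in name for m in embedding_modules):
        if hasattr(module, "weight"):
            module.to(torch_dtype)

assert model["final_layer_norm"].weight.dtype == torch.bfloat16
assert model["embed_in"].weight.dtype == torch.bfloat16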

View File

@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )
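
With model_config_type now present in the fixture, config-driven helpers resolve to the llama defaults during these tests; roughly (import paths assumed):

from axolotl.utils.dict import DictDefault  # import path assumed
from axolotl.utils.lora_embeddings import get_linear_embedding_layers  # import path assumed

cfg = DictDefault({"model_config_type": "llama"})  # trimmed-down stand-in for the fixture
assert get_linear_embedding_layers(cfg.model_config_type) == ["embed_tokens", "lm_head"]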

View File

@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )
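
The validation this test exercises ties newly added tokens to trainable embedding layers; a simplified stand-in (not the real validator) for what it checks:

def check_lora_modules_to_save(cfg, embedding_modules):
    """Simplified stand-in for the validation exercised above, not the real implementation."""
    if cfg.get("adapter") in ("lora", "qlora") and cfg.get("tokens"):
        saved = cfg.get("lora_modules_to_save") or []
        missing = [m for m in embedding_modules if m not in saved]
        if missing:
            raise ValueError(
                f"lora_modules_to_save must include {embedding_modules} when adding new tokens"
            )


# the phi-msft case from the test passes because embd.wte and lm_head.linear are listed
check_lora_modules_to_save(
    {
        "adapter": "qlora",
        "tokens": ["<|imstart|>"],
        "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
    },
    ["embd.wte", "lm_head.linear"],
)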