add gptneox embeddings, fix phi2 inputs, also fix the casting (#1083)
@@ -8,5 +8,7 @@ def get_linear_embedding_layers(model_type):
     returns the linear embedding layers needed for loras, dependent on the model arch
     """
     if model_type == "phi-msft":
-        return ["embd", "lm_head.linear"]
-    return ["lm_head", "embed_tokens"]
+        return ["embd.wte", "lm_head.linear"]
+    if model_type == "gpt_neox":
+        return ["embed_in", "embed_out"]
+    return ["embed_tokens", "lm_head"]
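For reference, the helper after this change reads roughly as follows. This is a sketch assembled from the hunk above rather than copied verbatim from the repository, and the example calls in the trailing comments are illustrative only.

def get_linear_embedding_layers(model_type):
    """
    returns the linear embedding layers needed for loras, dependent on the model arch
    """
    if model_type == "phi-msft":
        return ["embd.wte", "lm_head.linear"]
    if model_type == "gpt_neox":
        return ["embed_in", "embed_out"]
    return ["embed_tokens", "lm_head"]


# get_linear_embedding_layers("gpt_neox")  -> ["embed_in", "embed_out"]
# get_linear_embedding_layers("phi-msft")  -> ["embd.wte", "lm_head.linear"]
# get_linear_embedding_layers("llama")     -> ["embed_tokens", "lm_head"]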
@@ -588,13 +588,14 @@ def load_model(
     log_gpu_memory_usage(LOG, "after model load", model.device)
 
     # make sure these are fp32 per Ramesh et al. (2021)
+    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
     for name, module in model.named_modules():
         if "norm" in name:
             module.to(torch.float32)
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(m in name for m in embedding_modules):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
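The substance of the casting fix is in the condition above: instead of hard-coding the LLaMA-style names "lm_head" and "embed_tokens", the fp32 upcast now matches module names against the per-architecture list, so phi-msft ("embd.wte", "lm_head.linear") and gpt_neox ("embed_in", "embed_out") embeddings get upcast as well. Below is a minimal, self-contained sketch of that pattern; the toy model is purely illustrative and stands in for a model loaded through transformers.

import torch
import torch.nn as nn

# toy stand-in for a loaded causal LM (a real model would come from transformers)
model = nn.ModuleDict(
    {"embed_in": nn.Embedding(10, 4), "embed_out": nn.Linear(4, 10)}
).half()

embedding_modules = ["embed_in", "embed_out"]  # get_linear_embedding_layers("gpt_neox")
for name, module in model.named_modules():
    # upcast embedding/output modules to fp32 per Ramesh et al. (2021)
    if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
        module.to(torch.float32)

print({n: m.weight.dtype for n, m in model.named_modules() if hasattr(m, "weight")})
# expected: {'embed_in': torch.float32, 'embed_out': torch.float32}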
@@ -619,15 +620,12 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (
-        cfg.flash_attention
-        and (cfg.is_llama_derived_model or cfg.is_mistral_derived_model)
-    ):
+    if needs_fa2_dtype or cfg.flash_attention:
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(m in name for m in embedding_modules):
                 if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
 
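Taken together, the post-load dtype handling now has the shape sketched below. fixup_dtypes is a hypothetical name used only for this illustration; embedding_modules is assumed to be the list returned by get_linear_embedding_layers, and cfg, model_config, and needs_fa2_dtype are the objects load_model already has in scope.

import torch


def fixup_dtypes(model, cfg, model_config, embedding_modules, needs_fa2_dtype):
    """Condensed sketch of load_model's dtype handling after this change (hypothetical helper)."""
    # upcast norms plus the arch-specific embedding/output modules to fp32
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch.float32)
        if model_config.model_type == "btlm":
            # don't upcast lm_head for btlm
            continue
        if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
            module.to(torch.float32)

    # flash-attn wants a uniform fp16/bf16 dtype, so convert those same modules back
    if needs_fa2_dtype or cfg.flash_attention:
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(cfg.torch_dtype)
            if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
                module.to(cfg.torch_dtype)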
@@ -30,6 +30,7 @@ def fixture_cfg():
             "adam_epsilon": 0.00001,
             "dataloader_num_workers": 1,
             "dataloader_pin_memory": True,
+            "model_config_type": "llama",
         }
     )
 
@@ -770,7 +770,7 @@ class ValidationCheckModelConfig(BaseValidation):
                 "adapter": "qlora",
                 "load_in_4bit": True,
                 "tokens": ["<|imstart|>"],
-                "lora_modules_to_save": ["embd", "lm_head.linear"],
+                "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
             }
         )
 