Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
This commit is contained in:
@@ -368,7 +368,7 @@ def load_model(
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if (fix_dtype or not cfg.adapter) and (
|
||||
if fix_dtype and (
|
||||
cfg.flash_attention and cfg.is_llama_derived_model
|
||||
):
|
||||
for name, module in model.named_modules():
|
||||
|
||||
Reference in New Issue
Block a user