From f311df9462bf348317de57404a9abe6305da19d3 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Sat, 26 Aug 2023 22:34:11 +0200
Subject: [PATCH 1/7] fix: finetune model inference needs the dtype fix to work with flash-attn

---
 src/axolotl/utils/models.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 261acd934..c95e346e1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,6 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
+    fix_dtype = False
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -363,16 +364,19 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
+        fix_dtype = True
 
-        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if cfg.flash_attention and cfg.is_llama_derived_model:
-            for name, module in model.named_modules():
-                if "norm" in name:
+    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if (fix_dtype or cfg.adapter == "" or cfg.adapter == None) and (
+        cfg.flash_attention and cfg.is_llama_derived_model
+    ):
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
                     module.to(cfg.torch_dtype)
-                if "lm_head" in name or "embed_tokens" in name:
-                    if hasattr(module, "weight"):
-                        module.to(cfg.torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 

From a184549e4c44651555170eac5dc3384842d34112 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Sat, 26 Aug 2023 22:36:14 +0200
Subject: [PATCH 2/7] ignore: linter

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c95e346e1..fc2cf04b3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -368,7 +368,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if (fix_dtype or cfg.adapter == "" or cfg.adapter == None) and (
+    if (fix_dtype or cfg.adapter == "" or cfg.adapter is None) and (
         cfg.flash_attention and cfg.is_llama_derived_model
     ):
         for name, module in model.named_modules():

From d03887fad5044a90b1984baaad36387079ecd4f6 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Sat, 26 Aug 2023 22:45:45 +0200
Subject: [PATCH 3/7] ignore: address pr review

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fc2cf04b3..71e27a2bc 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -368,7 +368,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if (fix_dtype or cfg.adapter == "" or cfg.adapter is None) and (
+    if (fix_dtype or not cfg.adapter) and (
         cfg.flash_attention and cfg.is_llama_derived_model
     ):
         for name, module in model.named_modules():
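Context for patches 1-3: peft's prepare_model_for_kbit_training upcasts the non-quantized parameters (norms, lm_head, embed_tokens) to fp32, while flash-attn kernels only accept fp16/bf16 tensors; a full finetune with no adapter can leave those modules in fp32 as well, which is why the cast is extended to the no-adapter case. A minimal sketch of how one might observe the fp32 stragglers, assuming a bitsandbytes-capable GPU environment; the checkpoint name is a placeholder, not taken from the patches:

import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

# Illustrative only -- not part of the patch series.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder Llama checkpoint
    load_in_8bit=True,
    torch_dtype=torch.float16,
)
model = prepare_model_for_kbit_training(model)

for name, module in model.named_modules():
    weight = getattr(module, "weight", None)
    if weight is not None and weight.dtype == torch.float32:
        print(name)  # norm, lm_head and embed_tokens modules appear here

These are exactly the modules the patch casts back to cfg.torch_dtype before inference.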
From 9e699683d79a21aeffde7970f4af07febbd341e8 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Sun, 27 Aug 2023 21:01:37 +0200
Subject: [PATCH 4/7] Update src/axolotl/utils/models.py

Co-authored-by: Aman Gupta Karmani
---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 71e27a2bc..ed917d963 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,7 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    fix_dtype = False
+    fix_dtype = not cfg.adapter
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)

From 7fd662dd89e4fb8e97a7b1fbb4328f33220f60c1 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Sun, 27 Aug 2023 21:01:43 +0200
Subject: [PATCH 5/7] Update src/axolotl/utils/models.py

Co-authored-by: Aman Gupta Karmani
---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index ed917d963..4575f5966 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -368,7 +368,7 @@ def load_model(
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if (fix_dtype or not cfg.adapter) and (
+    if fix_dtype and (
         cfg.flash_attention and cfg.is_llama_derived_model
     ):
         for name, module in model.named_modules():

From f319b0bc67b548f509ca8ddc3922c028c733bea7 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 27 Aug 2023 19:55:11 +0000
Subject: [PATCH 6/7] rename var and reformat

---
 src/axolotl/utils/models.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 4575f5966..dd75106ec 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,7 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    fix_dtype = not cfg.adapter
+    needs_fa2_dtype = not cfg.adapter
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -364,13 +364,11 @@ def load_model(
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
-        fix_dtype = True
+        needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if fix_dtype and (
-        cfg.flash_attention and cfg.is_llama_derived_model
-    ):
+    if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)

From 3a011ea1ef4ddee446e22849651783dd758dfda6 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Sun, 27 Aug 2023 20:09:26 +0000
Subject: [PATCH 7/7] fix condition and add logging

---
 src/axolotl/utils/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index dd75106ec..c2fbc19e3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -355,7 +355,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    needs_fa2_dtype = not cfg.adapter
+    needs_fa2_dtype = cfg.adapter is not None
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -369,6 +369,7 @@ def load_model(
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
     if needs_fa2_dtype and (cfg.flash_attention and cfg.is_llama_derived_model):
+        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
                 module.to(cfg.torch_dtype)
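Net effect of the series, as a self-contained sketch; the helper name and signature are illustrative, since in axolotl this logic lives inline in load_model and reads its target dtype from cfg.torch_dtype:

import torch
import torch.nn as nn

def cast_modules_for_flash_attn(model: nn.Module, torch_dtype: torch.dtype) -> None:
    # flash-attn kernels only support fp16/bf16, so cast back the modules
    # that kbit preparation (or a full-finetune load) left in fp32.
    for name, module in model.named_modules():
        if "norm" in name:
            module.to(torch_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module.to(torch_dtype)

In load_model this runs only when needs_fa2_dtype is set and both cfg.flash_attention and cfg.is_llama_derived_model are true, e.g. cast_modules_for_flash_attn(model, torch.bfloat16) for a bf16 run.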