From 744f7082f51dc205097963343a9e498a46d82560 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Apr 2024 17:02:54 -0700 Subject: [PATCH] fix for fsdp for models that aren't qwen2 or jamba --- src/axolotl/utils/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 295adefa5..b29ddef3a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -459,7 +459,7 @@ def load_model( "bnb_4bit_quant_type": "nf4", "bnb_4bit_quant_storage": torch.bfloat16, } - if not cfg.deepspeed: + if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"): # for some reason, this causes the loss to be off by an order of magnitude # but deepspeed needs this still in bfloat16 bnb_config["bnb_4bit_quant_storage"] = torch.float32