diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ba71ea459..8ba26543c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1309,7 +1309,7 @@ class ModelLoader: # make sure these are fp32 per Ramesh et al. (2021) embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) - if not self.cfg.fsdp: + if self.cfg.fsdp: # FSDP doesn't like mixed Float and BFloat16 self.convert_embedding_modules_dtype( embedding_modules,