fsdp embeddings should be float32 per comment

commit 37c27aedc1
parent ed922796b7
Author: Wing Lian
Date:   2025-05-03 01:56:09 -04:00

@@ -1309,7 +1309,7 @@ class ModelLoader:
         # make sure these are fp32 per Ramesh et al. (2021)
         embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
-        if not self.cfg.fsdp:
+        if self.cfg.fsdp:
             # FSDP doesn't like mixed Float and BFloat16
             self.convert_embedding_modules_dtype(
                 embedding_modules,
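
The change flips an inverted condition: the in-code comment notes that FSDP does not handle a mix of float32 and bfloat16 parameters, so the embedding modules should be cast to float32 when FSDP is enabled, not when it is disabled, which is what the commit message's "per comment" refers to. The body of convert_embedding_modules_dtype is not shown in this diff; the sketch below is an assumption of what such a cast could look like for standard PyTorch modules, written as a free function rather than the repository's ModelLoader method so it stays self-contained. The signature and the name-suffix matching are illustrative, not the project's actual implementation.

import torch
from torch import nn


def convert_embedding_modules_dtype(
    model: nn.Module,
    embedding_modules: list[str],
    dist_dtype: torch.dtype = torch.float32,
) -> None:
    # Sketch only: cast every module whose qualified name ends with one of
    # the configured embedding layer names (e.g. "embed_tokens", "lm_head")
    # to dist_dtype, so FSDP sees a uniform parameter dtype for these layers.
    for name, module in model.named_modules():
        if any(name.endswith(suffix) for suffix in embedding_modules):
            module.to(dist_dtype)

In the repository the helper is a method that operates on the loader's own model, with embedding_modules supplied by get_linear_embedding_layers as in the diff above.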