fsdp embeddings should be float32 per comment
@@ -1309,7 +1309,7 @@ class ModelLoader:
         # make sure these are fp32 per Ramesh et al. (2021)
         embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
-        if not self.cfg.fsdp:
+        if self.cfg.fsdp:
             # FSDP doesn't like mixed Float and BFloat16
             self.convert_embedding_modules_dtype(
                 embedding_modules,
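Only the call site is visible in this hunk. As a minimal sketch of what the conversion could look like, assuming the helper simply upcasts the named embedding layers (the signature and body below are guesses, not axolotl's actual implementation):

# Hypothetical sketch; axolotl's real convert_embedding_modules_dtype
# may have a different signature and body.
import torch
import torch.nn as nn


def convert_embedding_modules_dtype(
    model: nn.Module,
    embedding_modules: list[str],
    dtype: torch.dtype = torch.float32,
) -> None:
    """Upcast embedding-like layers so FSDP sees a uniform param dtype."""
    for name, module in model.named_modules():
        # embedding_modules holds layer-name suffixes, e.g. "embed_tokens"
        # or "lm_head" (assumed here based on the get_linear_embedding_layers
        # call above).
        if any(name.endswith(target) for target in embedding_modules):
            module.to(dtype)

Keeping these parameters in float32 avoids FSDP flattening Float and BFloat16 tensors into the same shard, which is the mixed-dtype problem the in-diff comment refers to.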