don't shrink embeddings unless told to

This commit is contained in:
Wing Lian
2025-02-03 10:10:22 -05:00
parent 3c7517fd55
commit 2c1376d8c4
2 changed files with 8 additions and 3 deletions

View File

@@ -667,6 +667,8 @@ class AxolotlInputConfig(
auto_resume_from_checkpoints: Optional[bool] = None auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False mean_resizing_embeddings: Optional[bool] = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller than the model's current embedding size
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None rl: Optional[RLType] = None
reward_model: Optional[bool] = None reward_model: Optional[bool] = None

View File

@@ -1053,9 +1053,12 @@ class ModelLoader:
if self.cfg.resize_token_embeddings_to_32x if self.cfg.resize_token_embeddings_to_32x
else len(self.tokenizer) else len(self.tokenizer)
) )
if ( if hasattr(self.model, "get_input_embeddings") and (
hasattr(self.model, "get_input_embeddings") self.model.get_input_embeddings().num_embeddings < embeddings_len
and self.model.get_input_embeddings().num_embeddings != embeddings_len or (
self.model.get_input_embeddings().num_embeddings > embeddings_len
and self.cfg.shrink_embeddings
)
): ):
resize_kwargs = {} resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: if self.cfg.mean_resizing_embeddings is not None: