don't shrink embeddings unless told to

This commit is contained in:
Wing Lian
2025-02-03 10:10:22 -05:00
parent 3c7517fd55
commit 2c1376d8c4
2 changed files with 8 additions and 3 deletions

View File

@@ -667,6 +667,8 @@ class AxolotlInputConfig(
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller than
# the model's current embedding size; when unset, larger embeddings are left as-is
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None
reward_model: Optional[bool] = None

View File

@@ -1053,9 +1053,12 @@ class ModelLoader:
if self.cfg.resize_token_embeddings_to_32x
else len(self.tokenizer)
)
if (
hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings != embeddings_len
if hasattr(self.model, "get_input_embeddings") and (
self.model.get_input_embeddings().num_embeddings < embeddings_len
or (
self.model.get_input_embeddings().num_embeddings > embeddings_len
and self.cfg.shrink_embeddings
)
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: