From 2c1376d8c4ae7e0d21df7be641394f576a3a8db9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 3 Feb 2025 10:10:22 -0500 Subject: [PATCH] don't shrink embeddings unless told to --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 ++ src/axolotl/utils/models.py | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 0607f8af1..d9df16943 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -667,6 +667,8 @@ class AxolotlInputConfig( auto_resume_from_checkpoints: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None mean_resizing_embeddings: Optional[bool] = False + # optionally shrink the embeddings when the tokenizer vocab size is smaller than the model's embedding size + shrink_embeddings: Optional[bool] = None rl: Optional[RLType] = None reward_model: Optional[bool] = None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d46564f42..2437e2d47 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1053,9 +1053,12 @@ class ModelLoader: if self.cfg.resize_token_embeddings_to_32x else len(self.tokenizer) ) - if ( - hasattr(self.model, "get_input_embeddings") - and self.model.get_input_embeddings().num_embeddings != embeddings_len + if hasattr(self.model, "get_input_embeddings") and ( + self.model.get_input_embeddings().num_embeddings < embeddings_len + or ( + self.model.get_input_embeddings().num_embeddings > embeddings_len + and self.cfg.shrink_embeddings + ) ): resize_kwargs = {} if self.cfg.mean_resizing_embeddings is not None: