From e1e0556c9951ef53ee627310bf3d248908fdf39a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 28 Oct 2024 17:02:04 -0400 Subject: [PATCH] add option for resizing embeddings when adding new tokens (#2000) * add option for resizing embeddings when adding new tokens * let's just be opinonated about this setting and set it to False --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/models.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4831da3c8..16cf312ce 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -549,6 +549,7 @@ class AxolotlInputConfig( resume_from_checkpoint: Optional[str] = None auto_resume_from_checkpoints: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None + mean_resizing_embeddings: Optional[bool] = False rl: Optional[RLType] = None reward_model: Optional[bool] = None diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8b433c366..97844a5bf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1042,7 +1042,10 @@ class ModelLoader: hasattr(self.model, "get_input_embeddings") and self.model.get_input_embeddings().num_embeddings < embeddings_len ): - self.model.resize_token_embeddings(embeddings_len) + resize_kwargs = {} + if self.cfg.mean_resizing_embeddings is not None: + resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings + self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) else: self.model.tie_weights()