diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 4831da3c8..16cf312ce 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -549,6 +549,7 @@ class AxolotlInputConfig(
     resume_from_checkpoint: Optional[str] = None
     auto_resume_from_checkpoints: Optional[bool] = None
     resize_token_embeddings_to_32x: Optional[bool] = None
+    mean_resizing_embeddings: Optional[bool] = False

     rl: Optional[RLType] = None
     reward_model: Optional[bool] = None
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8b433c366..97844a5bf 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1042,7 +1042,10 @@ class ModelLoader:
         if (
             hasattr(self.model, "get_input_embeddings")
             and self.model.get_input_embeddings().num_embeddings < embeddings_len
         ):
-            self.model.resize_token_embeddings(embeddings_len)
+            resize_kwargs = {}
+            if self.cfg.mean_resizing_embeddings is not None:
+                resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
+            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
         else:
             self.model.tie_weights()