diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ca56e79d8..ab5bbc267 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -302,7 +302,10 @@ def load_model( if cfg.resize_token_embeddings_to_32x else len(tokenizer) ) - model.resize_token_embeddings(embeddings_len) + if model.get_input_embeddings().num_embeddings < embeddings_len: + model.resize_token_embeddings(embeddings_len) + else: + model.tie_weights() if ( hasattr(model.config, "max_position_embeddings")