From 31079cd5fd1e5c0ceaeca15c4275ef088f93d3dd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 7 Aug 2023 10:15:10 -0400
Subject: [PATCH] smart resize embeddings

---
 README.md                   |  6 ++---
 src/axolotl/utils/models.py | 50 ++++++++++++++++++++++++++++++++-----
 2 files changed, 47 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 72cd71f22..975393799 100644
--- a/README.md
+++ b/README.md
@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
 trust_remote_code:
 # use_fast option for tokenizer loading from_pretrained, default to True
 tokenizer_use_fast:
-# resize the model embeddings when new tokens are added to multiples of 32
-# this is reported to improve training speed on some models
-resize_token_embeddings_to_32x:
+# resize the model embeddings to a multiple of N when new tokens are added
+# multiples of 32 are reported to improve training speed on some models
+resize_token_embeddings_multiple:
 
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index ce2d14f47..6c2f1f31c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -32,6 +32,45 @@ if TYPE_CHECKING:
     from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
+def smart_tokenizer_and_embedding_resize(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    resize_token_embeddings_multiple: Optional[int] = None,
+):
+    """Resize the model's token embeddings to match the tokenizer.
+
+    Note: This function resizes the embedding matrix of the model to the new size of
+    the tokenizer, optionally rounded up to a multiple of `resize_token_embeddings_multiple`.
+    If new tokens have been added, their embeddings (and any padding rows introduced by
+    the rounding) are initialized to the average of the pre-existing embeddings. This is
+    done separately for the input embeddings and output embeddings of the model.
+    """
+
+    old_tokens = model.get_input_embeddings().weight.data.shape[0]
+    num_new_tokens = len(tokenizer) - old_tokens
+    embeddings_len = (
+        math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
+        * resize_token_embeddings_multiple
+        if resize_token_embeddings_multiple
+        else len(tokenizer)
+    )
+    model.resize_token_embeddings(embeddings_len)
+
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:old_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:old_tokens].mean(
+            dim=0, keepdim=True
+        )
+
+        input_embeddings[old_tokens:] = input_embeddings_avg
+        output_embeddings[old_tokens:] = output_embeddings_avg
+
+
 def load_tokenizer(cfg):
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
@@ -327,17 +366,16 @@ def load_model(
             **model_kwargs,
         )
 
-    embeddings_len = (
-        math.ceil(len(tokenizer) / 32) * 32
-        if cfg.resize_token_embeddings_to_32x
-        else len(tokenizer)
+    smart_tokenizer_and_embedding_resize(
+        tokenizer,
+        model,
+        resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
     )
-    model.resize_token_embeddings(embeddings_len)
 
     if (
         hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
-        and cfg.sequence_len >= model.config.max_position_embeddings
+        and cfg.sequence_len > model.config.max_position_embeddings
     ):
         LOG.warning(
             f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
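
For reviewers, a rough usage sketch of the new helper follows. It is not part of
the patch; the sshleifer/tiny-gpt2 checkpoint and the [PAD] token are illustrative
assumptions, and in axolotl itself the call happens inside load_model with the
multiple taken from the config (e.g. `resize_token_embeddings_multiple: 32` in
the YAML).

    import transformers

    from axolotl.utils.models import smart_tokenizer_and_embedding_resize

    # a tiny checkpoint keeps the demo fast; any causal LM works here
    model_name = "sshleifer/tiny-gpt2"
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

    old_len = model.get_input_embeddings().weight.shape[0]  # 50257 for GPT-2

    # add a new special token, then resize up to the next multiple of 32
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    smart_tokenizer_and_embedding_resize(
        tokenizer, model, resize_token_embeddings_multiple=32
    )

    # the embedding matrix is now the smallest multiple of 32 >= len(tokenizer):
    # ceil(50258 / 32) * 32 == 50272
    print(len(tokenizer), model.get_input_embeddings().weight.shape[0])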
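
And a quick sanity check of the mean-initialization, continuing the sketch above
(an exact comparison works because the new rows are assigned, not recomputed):

    # every row added by the resize -- new tokens and padding alike -- should
    # equal the mean of the original embedding rows
    emb = model.get_input_embeddings().weight.data
    orig_mean = emb[:old_len].mean(dim=0)
    assert bool((emb[old_len:] == orig_mean).all())

Initializing new rows to the average of the existing embeddings, rather than
leaving them at their random default, appears to follow the Stanford Alpaca
training script, from which the helper's name is borrowed.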