Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
31079cd5fd smart resize embeddings
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-14 23:44:15 -04:00
2 changed files with 47 additions and 9 deletions

View File

@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# resize the model embeddings when new tokens are added to multiples of N
# multiples of 32 are reported to improve training speed on some models
resize_token_embeddings_multiple:
# whether you are training a 4-bit GPTQ quantized model
gptq: true

View File

@@ -32,6 +32,45 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
resize_token_embeddings_multiple: Optional[int] = None,
):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -327,17 +366,16 @@ def load_model(
**model_kwargs,
)
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
smart_tokenizer_and_embedding_resize(
tokenizer,
model,
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
)
model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"