@@ -32,6 +32,45 @@ if TYPE_CHECKING:
     from axolotl.utils.dict import DictDefault  # noqa: F401


+def smart_tokenizer_and_embedding_resize(
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+    resize_token_embeddings_multiple: Optional[int] = None,
+):
+    """Resize tokenizer and embedding.
+
+    Note: This function resizes the tokenizer to accommodate additional special tokens and the
+    embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
+    have been added, the function computes the average embedding values of the existing embeddings
+    and sets those values for the new special token embeddings. This is done separately for the input
+    embeddings and output embeddings of the model.
+    """
+
+    old_tokens = model.get_input_embeddings().weight.data.shape[0]
+    num_new_tokens = len(tokenizer) - old_tokens
+    embeddings_len = (
+        math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
+        * resize_token_embeddings_multiple
+        if resize_token_embeddings_multiple
+        else len(tokenizer)
+    )
+    model.resize_token_embeddings(embeddings_len)
+
+    if num_new_tokens > 0:
+        input_embeddings = model.get_input_embeddings().weight.data
+        output_embeddings = model.get_output_embeddings().weight.data
+
+        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+            dim=0, keepdim=True
+        )
+
+        input_embeddings[-num_new_tokens:] = input_embeddings_avg
+        output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
 def load_tokenizer(cfg):
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
@@ -327,17 +366,16 @@ def load_model(
         **model_kwargs,
     )

-    embeddings_len = (
-        math.ceil(len(tokenizer) / 32) * 32
-        if cfg.resize_token_embeddings_to_32x
-        else len(tokenizer)
+    smart_tokenizer_and_embedding_resize(
+        tokenizer,
+        model,
+        resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
     )
-    model.resize_token_embeddings(embeddings_len)

     if (
         hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
-        and cfg.sequence_len >= model.config.max_position_embeddings
+        and cfg.sequence_len > model.config.max_position_embeddings
     ):
         LOG.warning(
             f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
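
And a hedged sketch of the new call site that `load_model` now uses: the
checkpoint name, the added pad token, and the import path are assumptions for
illustration (the helper is defined in the module this diff patches,
presumably `axolotl.utils.models`).

```python
import transformers

# Assumed import path; the function is defined in the patched module.
from axolotl.utils.models import smart_tokenizer_and_embedding_resize

# Hypothetical checkpoint and special token, for illustration only.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# One call now covers both the rounding and the mean initialization that
# load_model previously inlined.
smart_tokenizer_and_embedding_resize(
    tokenizer,
    model,
    resize_token_embeddings_multiple=32,
)
```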