Compare commits
1 Commits
NanoCode01
...
embeddings
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31079cd5fd |
@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
|
|||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||||
tokenizer_use_fast:
|
tokenizer_use_fast:
|
||||||
# resize the model embeddings when new tokens are added to multiples of 32
|
# resize the model embeddings when new tokens are added to multiples of N
|
||||||
# this is reported to improve training speed on some models
|
# multiples of 32 are reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_multiple:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
|
|||||||
@@ -32,6 +32,45 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def smart_tokenizer_and_embedding_resize(
    tokenizer: "transformers.PreTrainedTokenizer",
    model: "transformers.PreTrainedModel",
    resize_token_embeddings_multiple: Optional[int] = None,
):
    """Resize tokenizer and embedding.

    Grows the model's input and output embedding matrices to match the
    tokenizer, optionally padding the matrix size up to a multiple of
    ``resize_token_embeddings_multiple`` (multiples of 32 are reported to
    improve training speed on some hardware). Every row added beyond the
    original vocabulary — the new special tokens *and* any padding rows —
    is initialized to the mean of the pre-existing embeddings, separately
    for the input and output embeddings.

    Args:
        tokenizer: tokenizer whose length defines the target vocab size.
        model: model exposing ``get_input_embeddings`` /
            ``get_output_embeddings`` / ``resize_token_embeddings``.
        resize_token_embeddings_multiple: if truthy, round the embedding
            matrix size up to this multiple; otherwise use ``len(tokenizer)``.
    """
    # Vocabulary size before the resize; rows [0, old_tokens) are trained.
    old_tokens = model.get_input_embeddings().weight.data.shape[0]
    num_new_tokens = len(tokenizer) - old_tokens

    # Round the matrix size up to the requested multiple (no-op when unset).
    embeddings_len = (
        math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
        * resize_token_embeddings_multiple
        if resize_token_embeddings_multiple
        else len(tokenizer)
    )
    model.resize_token_embeddings(embeddings_len)

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # BUGFIX: average only the rows that existed before the resize.
        # The previous code used `[:-num_new_tokens]`, which — once the
        # matrix is padded past len(tokenizer) — mixed freshly-initialized
        # padding rows into the average and then wrote the result into the
        # trailing padding rows instead of the actual new-token rows.
        input_embeddings_avg = input_embeddings[:old_tokens].mean(
            dim=0, keepdim=True
        )
        output_embeddings_avg = output_embeddings[:old_tokens].mean(
            dim=0, keepdim=True
        )

        # Initialize every added row (new special tokens at
        # [old_tokens, len(tokenizer)) plus any padding rows beyond) to the
        # average, so padded output rows don't carry random-init logits.
        input_embeddings[old_tokens:] = input_embeddings_avg
        output_embeddings[old_tokens:] = output_embeddings_avg
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
@@ -327,17 +366,16 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
embeddings_len = (
|
smart_tokenizer_and_embedding_resize(
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
tokenizer,
|
||||||
if cfg.resize_token_embeddings_to_32x
|
model,
|
||||||
else len(tokenizer)
|
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
|
||||||
)
|
)
|
||||||
model.resize_token_embeddings(embeddings_len)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
and cfg.sequence_len > model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||||
|
|||||||
Reference in New Issue
Block a user