update from unsloth-zoo with additional fixes (#2122)
Only update tokens seen in the train dataset, and log them explicitly
This commit is contained in:
@@ -18,21 +18,79 @@ helper functions for fixing the embeddings/tokenizer
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
|
||||||
|
|
||||||
@torch.inference_mode
|
|
||||||
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
@torch.inference_mode()
|
||||||
|
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
|
||||||
|
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Many of the newer models have reserved tokens that are not trained.
|
Llama-3, for example, has untrained vectors in the base model.
|
||||||
|
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
|
||||||
|
We reset them to the mean of the rest of the tokens
|
||||||
"""
|
"""
|
||||||
|
# Code licensed under LGPL
|
||||||
embedding_matrix = model.get_input_embeddings().weight
|
embedding_matrix = model.get_input_embeddings().weight
|
||||||
lm_head_matrix = model.get_output_embeddings().weight
|
lm_head_matrix = model.get_output_embeddings().weight
|
||||||
|
chat_template = getattr(tokenizer, "chat_template", None)
|
||||||
|
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
|
||||||
|
|
||||||
|
# Ignore some model checks for now
|
||||||
|
if not ignored_tokenizer_names:
|
||||||
|
ignored_tokenizer_names = []
|
||||||
|
if (
|
||||||
|
model.config._name_or_path # pylint: disable=protected-access
|
||||||
|
in ignored_tokenizer_names
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sometimes the sizes can be different like in vision models
|
||||||
|
# i.e. <image> is in the input, but not in the output
|
||||||
|
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
|
||||||
|
embedding_matrix = embedding_matrix[:, :min_size]
|
||||||
|
lm_head_matrix = lm_head_matrix[:, :min_size]
|
||||||
|
|
||||||
# Get untrained tokens
|
# Get untrained tokens
|
||||||
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
|
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
|
||||||
|
# Check lm_head as well
|
||||||
|
|
||||||
|
# Does NOT work for Llama 3.1!!
|
||||||
|
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
|
||||||
|
|
||||||
|
# We instead check for repeated vectors
|
||||||
|
lm_head_where = torch.where(indicator_untrained1)[0]
|
||||||
|
lm_head_bad = lm_head_matrix[lm_head_where]
|
||||||
|
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
|
||||||
|
counter = Counter()
|
||||||
|
for row in lm_head_bad:
|
||||||
|
counter[hash(row.data.tobytes())] += 1
|
||||||
|
counter = Counter({k: c for k, c in counter.items() if c >= 2})
|
||||||
|
|
||||||
|
lm_head_where = lm_head_where.cpu().numpy()
|
||||||
|
final_bad_lm_head = []
|
||||||
|
for j, row in enumerate(lm_head_bad):
|
||||||
|
if hash(row.data.tobytes()) in counter:
|
||||||
|
final_bad_lm_head.append(lm_head_where[j])
|
||||||
|
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
|
||||||
|
indicator_untrained2[final_bad_lm_head] = True
|
||||||
|
|
||||||
|
# Combine both checks
|
||||||
|
indicator_untrained = indicator_untrained1 & indicator_untrained2
|
||||||
|
|
||||||
|
# Remove pad token possibility
|
||||||
|
if hasattr(tokenizer, "pad_token_id"):
|
||||||
|
pad_token_id = tokenizer.pad_token_id
|
||||||
|
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
|
||||||
|
indicator_untrained[pad_token_id] = False
|
||||||
|
|
||||||
where_untrained = torch.where(indicator_untrained)[0]
|
where_untrained = torch.where(indicator_untrained)[0]
|
||||||
n_untrained = where_untrained.shape[0]
|
n_untrained = where_untrained.shape[0]
|
||||||
n_trained = embedding_matrix.shape[0] - n_untrained
|
n_trained = embedding_matrix.shape[0] - n_untrained
|
||||||
@@ -40,10 +98,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
# Get set and actual tokens
|
# Get set and actual tokens
|
||||||
where_untrained = where_untrained.tolist()
|
where_untrained = where_untrained.tolist()
|
||||||
if len(where_untrained) == 0:
|
if len(where_untrained) == 0:
|
||||||
return False
|
return
|
||||||
|
|
||||||
# Remove untrained indices where it's longer
|
# Remove untrained indices where it's longer
|
||||||
|
|
||||||
where_untrained_set = frozenset(where_untrained)
|
where_untrained_set = frozenset(where_untrained)
|
||||||
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
||||||
# Remove None items in actual_bad_tokens
|
# Remove None items in actual_bad_tokens
|
||||||
@@ -53,10 +110,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
if_bad_first = False
|
if_bad_first = False
|
||||||
if_bad_second = False
|
if_bad_second = False
|
||||||
# Check tokenizer's chat template for any untrained tokens
|
# Check tokenizer's chat template for any untrained tokens
|
||||||
chat_template = getattr(tokenizer, "chat_template", None)
|
|
||||||
if chat_template is not None:
|
if chat_template is not None:
|
||||||
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
||||||
|
|
||||||
|
if isinstance(train_dataset, datasets.IterableDataset):
|
||||||
|
# Skip the check, since the code below assumes
|
||||||
|
# an indexable dataset
|
||||||
|
return
|
||||||
|
|
||||||
# Check the first 250, last 250 input_ids
|
# Check the first 250, last 250 input_ids
|
||||||
size_dataset = len(train_dataset)
|
size_dataset = len(train_dataset)
|
||||||
size = min(size_dataset, 250)
|
size = min(size_dataset, 250)
|
||||||
@@ -83,7 +144,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
|
|
||||||
# Check if bad tokens exist!
|
# Check if bad tokens exist!
|
||||||
if not if_bad_first and not if_bad_second:
|
if not if_bad_first and not if_bad_second:
|
||||||
return False
|
return
|
||||||
|
|
||||||
|
# Check if lm_head / embed_token are trainable!
|
||||||
|
bad_not_trainable = False
|
||||||
|
if not embedding_matrix.requires_grad:
|
||||||
|
bad_not_trainable = True
|
||||||
|
if not lm_head_matrix.requires_grad:
|
||||||
|
bad_not_trainable = True
|
||||||
|
|
||||||
|
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
|
||||||
|
final_bad_items = []
|
||||||
|
|
||||||
|
# Re-check the first 250, last 250 input_ids
|
||||||
|
size_dataset = len(train_dataset)
|
||||||
|
size = min(size_dataset, 250)
|
||||||
|
for j in range(size):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Re-check last 250
|
||||||
|
left = max(size_dataset - 250, 0)
|
||||||
|
for j in range(left, size_dataset):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# If no bad tokens, possibly chat template itself has issues?
|
||||||
|
if len(final_bad_items) == 0:
|
||||||
|
# Recheck the first 2000 and last 2000 items
|
||||||
|
size_dataset = len(train_dataset)
|
||||||
|
size = min(size_dataset, 2000)
|
||||||
|
for j in range(size):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Re-check last 2000
|
||||||
|
left = max(size_dataset - 2000, 0)
|
||||||
|
for j in range(left, size_dataset):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Most likely false signal!
|
||||||
|
if len(final_bad_items) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
|
||||||
|
)
|
||||||
|
|
||||||
# Count all the possible bad tokens
|
# Count all the possible bad tokens
|
||||||
final_counts = np.zeros(
|
final_counts = np.zeros(
|
||||||
@@ -97,6 +220,23 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
|
|
||||||
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
||||||
|
|
||||||
|
# Get counts for untrained tokens
|
||||||
|
counts_untrained = final_counts[where_untrained]
|
||||||
|
# Identify untrained tokens seen in train_dataset
|
||||||
|
indices_seen_in_train = np.where(counts_untrained > 0)[0]
|
||||||
|
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
|
||||||
|
|
||||||
|
if len(tokens_to_update) == 0:
|
||||||
|
LOG.info(
|
||||||
|
"No untrained tokens found in train_dataset. No embeddings were modified."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Log the token IDs that are being rescaled
|
||||||
|
LOG.info(
|
||||||
|
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
|
||||||
|
)
|
||||||
|
|
||||||
# Get sum of all items
|
# Get sum of all items
|
||||||
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
||||||
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
||||||
@@ -113,38 +253,26 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
mean_embedding = sum_embedding / n_trained
|
mean_embedding = sum_embedding / n_trained
|
||||||
mean_lm_head = sum_lm_head / n_trained
|
mean_lm_head = sum_lm_head / n_trained
|
||||||
|
|
||||||
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
|
# Compute scaling for tokens to update
|
||||||
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
|
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
|
||||||
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
||||||
mean_embedding = (
|
|
||||||
mean_embedding.repeat(
|
|
||||||
(
|
|
||||||
n_untrained,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* scaling
|
|
||||||
)
|
|
||||||
mean_lm_head = (
|
|
||||||
mean_lm_head.repeat(
|
|
||||||
(
|
|
||||||
n_untrained,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* scaling
|
|
||||||
)
|
|
||||||
where_null = scaling.ravel() == 0
|
|
||||||
mean_embedding[where_null] = 0
|
|
||||||
mean_lm_head[where_null] = 0
|
|
||||||
|
|
||||||
# Set them to the mean
|
# Prepare mean embeddings for tokens to update
|
||||||
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
|
mean_embedding_repeated = (
|
||||||
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
|
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||||
|
)
|
||||||
|
mean_lm_head_repeated = (
|
||||||
|
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update embeddings only for tokens seen in train_dataset
|
||||||
|
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
|
||||||
|
embedding_matrix.dtype
|
||||||
|
)
|
||||||
|
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
return
|
||||||
return True
|
|
||||||
|
|||||||
Reference in New Issue
Block a user