Fix untrained tokens (#1771)
* fix untrained reserved tokens * save model after fixing untrained embeddings * don't need fsdp conditional here
This commit is contained in:
150
src/axolotl/core/tokenizer_utils.py
Normal file
150
src/axolotl/core/tokenizer_utils.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""
|
||||||
|
helper functions for fixing the embeddings/tokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _any_example_uses_ids(train_dataset, indices, token_ids):
    """Return True if any example at ``indices`` contains an id in ``token_ids``."""
    for idx in indices:
        example = train_dataset[idx]
        # Examples are normally mappings with an "input_ids" entry; anything
        # else is treated as the id sequence itself.
        if "input_ids" in example:
            example = example["input_ids"]
        if any(token_id in token_ids for token_id in example):
            return True
    return False


@torch.inference_mode()
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
    """
    Many of the newer models have reserved tokens that are not trained.

    A token is treated as untrained when the largest value in its input
    embedding row is <= ``eps``. If such a token appears in the tokenizer's
    chat template or in the first/last 250 examples of ``train_dataset``,
    every untrained row of both the input embeddings and the output
    embeddings (lm_head) is overwritten in place with the mean of the trained
    rows, scaled by ``count(token) / max(count)`` over the whole dataset.
    Untrained tokens that never occur in the dataset are set to zero.

    Args:
        model: object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``, each returning a module with a
            2-D ``weight`` tensor indexed by token id on the first axis.
        tokenizer: object supporting ``len()``, ``convert_ids_to_tokens`` and
            optionally a ``chat_template`` attribute.
        train_dataset: indexable, sized dataset whose examples carry
            ``input_ids`` and which provides a batched ``map`` method.
        eps: threshold at or below which a row's maximum marks it untrained.

    Returns:
        True if the embedding matrices were modified, False otherwise.
    """
    embedding_matrix = model.get_input_embeddings().weight
    lm_head_matrix = model.get_output_embeddings().weight

    # Get untrained tokens
    # NOTE(review): the test uses the row maximum, not the maximum magnitude,
    # so a row whose values are all negative is also flagged - confirm intent.
    indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
    where_untrained = torch.where(indicator_untrained)[0].tolist()
    n_untrained = len(where_untrained)
    if n_untrained == 0:
        return False

    n_trained = embedding_matrix.shape[0] - n_untrained
    if n_trained == 0:
        # No trained row to average over: the mean would be 0/0 and NaNs
        # would be written into the weights, so leave the model untouched.
        return False

    untrained_ids = frozenset(where_untrained)
    # convert_ids_to_tokens may yield None for ids unknown to the tokenizer
    bad_tokens = [
        token
        for token in tokenizer.convert_ids_to_tokens(where_untrained)
        if token is not None
    ]

    # Check tokenizer's chat template for any untrained tokens
    chat_template = getattr(tokenizer, "chat_template", None)
    used = chat_template is not None and any(
        token in chat_template for token in bad_tokens
    )

    # Otherwise sample the first 250 and last 250 examples of the dataset
    if not used:
        size_dataset = len(train_dataset)
        head_end = min(size_dataset, 250)
        # Start the tail after the head so small datasets are not read twice
        tail_start = max(size_dataset - 250, head_end)
        used = _any_example_uses_ids(
            train_dataset, range(head_end), untrained_ids
        ) or _any_example_uses_ids(
            train_dataset, range(tail_start, size_dataset), untrained_ids
        )

    if not used:
        return False

    # Count how often every token id occurs in the whole dataset
    final_counts = np.zeros(
        max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
    )

    def mapping(examples):
        token_ids = np.fromiter(
            itertools.chain.from_iterable(examples["input_ids"]), dtype=np.int32
        )
        # np.add.at accumulates correctly when an id repeats within a batch
        np.add.at(final_counts, token_ids, 1)

    train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")

    # Column sums over all rows (accumulated in float32), minus the
    # untrained rows, give the sum over trained rows only
    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
    sum_embedding -= torch.sum(
        embedding_matrix[where_untrained], dtype=torch.float32, axis=0
    )
    sum_lm_head -= torch.sum(
        lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
    )

    # Find correct average by dividing by the number of trained tokens
    mean_embedding = sum_embedding / n_trained
    mean_lm_head = sum_lm_head / n_trained

    # Scale each row by its frequency relative to the most frequent token;
    # tokens never seen get a factor of 0
    scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
    scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
    # Broadcasting (1, dim) * (n_untrained, 1) yields one row per token
    new_embedding = mean_embedding.unsqueeze(0) * scaling
    # The LM head may live on a different device than the input embeddings
    head_scaling = scaling.to(mean_lm_head.device)
    new_lm_head = mean_lm_head.unsqueeze(0) * head_scaling

    # Force unseen tokens to exactly zero even if the mean is not finite
    new_embedding[scaling.ravel() == 0] = 0
    new_lm_head[head_scaling.ravel() == 0] = 0

    # Write the replacement rows back in the matrices' own dtype
    embedding_matrix[where_untrained] = new_embedding.to(embedding_matrix.dtype)
    lm_head_matrix[where_untrained] = new_lm_head.to(lm_head_matrix.dtype)

    # Clean up
    for _ in range(3):
        gc.collect()
        torch.cuda.empty_cache()

    return True
|
||||||
@@ -19,6 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.core.tokenizer_utils import fix_untrained_tokens
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
@@ -123,6 +124,13 @@ def train(
|
|||||||
total_num_steps,
|
total_num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.fix_untrained_tokens:
|
||||||
|
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||||
|
if cfg.local_rank == 0:
|
||||||
|
model.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
|
|||||||
@@ -658,6 +658,8 @@ class AxolotlInputConfig(
|
|||||||
chat_template: Optional[ChatTemplate] = None
|
chat_template: Optional[ChatTemplate] = None
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
|
fix_untrained_tokens: Optional[bool] = None
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: Optional[bool] = None
|
is_preprocess: Optional[bool] = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user