use axolotl contribs for fix_untrained_tokens (#2194) [skip ci]
* Use axolotl contribs for fix_untrained_tokens
* Remove the module we're replacing
* Add check for using fix_untrained_tokens
This commit is contained in:
@@ -60,3 +60,5 @@ antlr4-python3-runtime==4.13.2
|
|||||||
|
|
||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
|
axolotl-contribs-lgpl==0.0.1b2
|
||||||
|
|||||||
@@ -1,3 +1,7 @@
|
|||||||
"""Axolotl - Train and fine-tune large language models"""
|
"""Axolotl - Train and fine-tune large language models"""
|
||||||
|
|
||||||
|
import pkgutil
|
||||||
|
|
||||||
|
# Make this a pkgutil-style namespace package: extend_path() adds every
# same-named package directory found on sys.path to this package's __path__,
# so subpackages shipped by separately installed distributions are importable.
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.6.0"
|
||||||
|
|||||||
@@ -1,272 +0,0 @@
|
|||||||
"""
|
|
||||||
helper functions for fixing the embeddings/tokenizer
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
|
||||||
# GNU LESSER GENERAL PUBLIC LICENSE
|
|
||||||
# Version 3, 29 June 2007
|
|
||||||
#
|
|
||||||
# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
|
||||||
# Everyone is permitted to copy and distribute verbatim copies
|
|
||||||
# of this license document, but changing it is not allowed.
|
|
||||||
|
|
||||||
import gc
|
|
||||||
import itertools
|
|
||||||
import logging
|
|
||||||
from collections import Counter
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
|
|
||||||
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Llama-3 for eg has untrained vectors in the base model.
|
|
||||||
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
|
|
||||||
We reset them to the mean of the rest of the tokens
|
|
||||||
"""
|
|
||||||
# Code licensed under LGPL
|
|
||||||
embedding_matrix = model.get_input_embeddings().weight
|
|
||||||
lm_head_matrix = model.get_output_embeddings().weight
|
|
||||||
chat_template = getattr(tokenizer, "chat_template", None)
|
|
||||||
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
|
|
||||||
|
|
||||||
# Ignore some model checks for now
|
|
||||||
if not ignored_tokenizer_names:
|
|
||||||
ignored_tokenizer_names = []
|
|
||||||
if (
|
|
||||||
model.config._name_or_path # pylint: disable=protected-access
|
|
||||||
in ignored_tokenizer_names
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Sometimes the sizes can be different like in vision models
|
|
||||||
# Ie <image> is in input, but not in output
|
|
||||||
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
|
|
||||||
embedding_matrix = embedding_matrix[:, :min_size]
|
|
||||||
lm_head_matrix = lm_head_matrix[:, :min_size]
|
|
||||||
|
|
||||||
# Get untrained tokens
|
|
||||||
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
|
|
||||||
# Check lm_head as well
|
|
||||||
|
|
||||||
# Does NOT work for Llama 3.1!!
|
|
||||||
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
|
|
||||||
|
|
||||||
# We instead check for repeated vectors
|
|
||||||
lm_head_where = torch.where(indicator_untrained1)[0]
|
|
||||||
lm_head_bad = lm_head_matrix[lm_head_where]
|
|
||||||
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
|
|
||||||
counter = Counter()
|
|
||||||
for row in lm_head_bad:
|
|
||||||
counter[hash(row.data.tobytes())] += 1
|
|
||||||
counter = Counter({k: c for k, c in counter.items() if c >= 2})
|
|
||||||
|
|
||||||
lm_head_where = lm_head_where.cpu().numpy()
|
|
||||||
final_bad_lm_head = []
|
|
||||||
for j, row in enumerate(lm_head_bad):
|
|
||||||
if hash(row.data.tobytes()) in counter:
|
|
||||||
final_bad_lm_head.append(lm_head_where[j])
|
|
||||||
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
|
|
||||||
indicator_untrained2[final_bad_lm_head] = True
|
|
||||||
|
|
||||||
# Combine both checks
|
|
||||||
indicator_untrained = indicator_untrained1 & indicator_untrained2
|
|
||||||
|
|
||||||
# Remove pad token possibility
|
|
||||||
if hasattr(tokenizer, "pad_token_id"):
|
|
||||||
pad_token_id = tokenizer.pad_token_id
|
|
||||||
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
|
|
||||||
indicator_untrained[pad_token_id] = False
|
|
||||||
|
|
||||||
where_untrained = torch.where(indicator_untrained)[0]
|
|
||||||
n_untrained = where_untrained.shape[0]
|
|
||||||
n_trained = embedding_matrix.shape[0] - n_untrained
|
|
||||||
|
|
||||||
# Get set and actual tokens
|
|
||||||
where_untrained = where_untrained.tolist()
|
|
||||||
if len(where_untrained) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove untrained indices where it's longer
|
|
||||||
where_untrained_set = frozenset(where_untrained)
|
|
||||||
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
|
||||||
# Remove None items in actual_bad_tokens
|
|
||||||
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
|
|
||||||
|
|
||||||
# Check if tokenizer and training datasets have bad tokens
|
|
||||||
if_bad_first = False
|
|
||||||
if_bad_second = False
|
|
||||||
# Check tokenizer's chat template for any untrained tokens
|
|
||||||
if chat_template is not None:
|
|
||||||
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
|
||||||
|
|
||||||
if isinstance(train_dataset, datasets.IterableDataset):
|
|
||||||
# Skip the check, since the code below assumes
|
|
||||||
# an indexable dataset
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check the first 250, last 250 input_ids
|
|
||||||
size_dataset = len(train_dataset)
|
|
||||||
size = min(size_dataset, 250)
|
|
||||||
for j in range(size):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
if_bad = any(item in where_untrained_set for item in input_ids)
|
|
||||||
if if_bad:
|
|
||||||
if_bad_second = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check last 250
|
|
||||||
if not if_bad_second:
|
|
||||||
left = max(size_dataset - 250, 0)
|
|
||||||
for j in range(left, size_dataset):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
if_bad = any(item in where_untrained_set for item in input_ids)
|
|
||||||
if if_bad:
|
|
||||||
if_bad_second = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check if bad tokens exists!
|
|
||||||
if not if_bad_first and not if_bad_second:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if lm_head / embed_token are trainable!
|
|
||||||
bad_not_trainable = False
|
|
||||||
if not embedding_matrix.requires_grad:
|
|
||||||
bad_not_trainable = True
|
|
||||||
if not lm_head_matrix.requires_grad:
|
|
||||||
bad_not_trainable = True
|
|
||||||
|
|
||||||
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
|
|
||||||
final_bad_items = []
|
|
||||||
|
|
||||||
# Re-check the first 250, last 250 input_ids
|
|
||||||
size_dataset = len(train_dataset)
|
|
||||||
size = min(size_dataset, 250)
|
|
||||||
for j in range(size):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
for item in input_ids:
|
|
||||||
if item in where_untrained_set:
|
|
||||||
final_bad_items.append(item)
|
|
||||||
|
|
||||||
# Re-check last 250
|
|
||||||
left = max(size_dataset - 250, 0)
|
|
||||||
for j in range(left, size_dataset):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
for item in input_ids:
|
|
||||||
if item in where_untrained_set:
|
|
||||||
final_bad_items.append(item)
|
|
||||||
|
|
||||||
# If no bad tokens, possibly chat template itself has issues?
|
|
||||||
if len(final_bad_items) == 0:
|
|
||||||
# Recheck 2000 and last 2000 items
|
|
||||||
size_dataset = len(train_dataset)
|
|
||||||
size = min(size_dataset, 2000)
|
|
||||||
for j in range(size):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
for item in input_ids:
|
|
||||||
if item in where_untrained_set:
|
|
||||||
final_bad_items.append(item)
|
|
||||||
|
|
||||||
# Re-check last 2000
|
|
||||||
left = max(size_dataset - 2000, 0)
|
|
||||||
for j in range(left, size_dataset):
|
|
||||||
input_ids = train_dataset[j]
|
|
||||||
if "input_ids" in input_ids:
|
|
||||||
input_ids = input_ids["input_ids"]
|
|
||||||
for item in input_ids:
|
|
||||||
if item in where_untrained_set:
|
|
||||||
final_bad_items.append(item)
|
|
||||||
|
|
||||||
# Most likely false signal!
|
|
||||||
if len(final_bad_items) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
|
|
||||||
)
|
|
||||||
|
|
||||||
# Count all the possible bad tokens
|
|
||||||
final_counts = np.zeros(
|
|
||||||
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
|
|
||||||
)
|
|
||||||
|
|
||||||
def mapping(examples):
|
|
||||||
input_ids = examples["input_ids"]
|
|
||||||
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
|
|
||||||
np.add.at(final_counts, counter, 1)
|
|
||||||
|
|
||||||
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
|
||||||
|
|
||||||
# Get counts for untrained tokens
|
|
||||||
counts_untrained = final_counts[where_untrained]
|
|
||||||
# Identify untrained tokens seen in train_dataset
|
|
||||||
indices_seen_in_train = np.where(counts_untrained > 0)[0]
|
|
||||||
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
|
|
||||||
|
|
||||||
if len(tokens_to_update) == 0:
|
|
||||||
LOG.info(
|
|
||||||
"No untrained tokens found in train_dataset. No embeddings were modified."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Log the token IDs that are being rescaled
|
|
||||||
LOG.info(
|
|
||||||
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get sum of all items
|
|
||||||
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
|
||||||
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
|
||||||
|
|
||||||
# Remove bad tokens
|
|
||||||
sum_embedding -= torch.sum(
|
|
||||||
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
|
|
||||||
)
|
|
||||||
sum_lm_head -= torch.sum(
|
|
||||||
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find correct average by dividing by sum of trained tokens
|
|
||||||
mean_embedding = sum_embedding / n_trained
|
|
||||||
mean_lm_head = sum_lm_head / n_trained
|
|
||||||
|
|
||||||
# Compute scaling for tokens to update
|
|
||||||
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
|
|
||||||
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
|
||||||
|
|
||||||
# Prepare mean embeddings for tokens to update
|
|
||||||
mean_embedding_repeated = (
|
|
||||||
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
|
||||||
)
|
|
||||||
mean_lm_head_repeated = (
|
|
||||||
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update embeddings only for tokens seen in train_dataset
|
|
||||||
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
|
|
||||||
embedding_matrix.dtype
|
|
||||||
)
|
|
||||||
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
for _ in range(3):
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return
|
|
||||||
@@ -19,7 +19,9 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.core.tokenizer_utils import fix_untrained_tokens
|
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||||
|
fix_untrained_tokens,
|
||||||
|
)
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ E2E tests for llama
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
@@ -13,18 +12,15 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestLlama(unittest.TestCase):
|
class TestLlama:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models
|
Test case for Llama models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fft_trust_remote_code(self, temp_dir):
|
def test_fft_trust_remote_code(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -46,7 +42,8 @@ class TestLlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 8,
|
"max_steps": 5,
|
||||||
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -64,3 +61,46 @@ class TestLlama(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
def test_fix_untrained_tokens(self, temp_dir):
    """
    Train for 5 steps with `fix_untrained_tokens` enabled and assert that
    `model.safetensors` is written to `temp_dir`.
    """
    # pylint: disable=duplicate-code
    finetome_dataset = {
        "path": "mlabonne/FineTome-100k",
        "type": "chat_template",
        "split": "train[:10%]",
        "field_messages": "conversations",
        "message_field_role": "from",
        "message_field_content": "value",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "fix_untrained_tokens": True,
        "sequence_len": 512,
        "val_set_size": 0.0,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "chat_template": "chatml",
        "datasets": [finetome_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_8bit",
        "lr_scheduler": "cosine",
        "flash_attention": True,
        "sample_packing": True,
        "bf16": True,
        "save_safetensors": True,
    }
    cfg = DictDefault(settings)
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    checkpoint = Path(temp_dir) / "model.safetensors"
    assert checkpoint.exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user