load the tokenizer seperately from the model

This commit is contained in:
Wing Lian
2023-05-26 07:29:35 -04:00
parent bbfc333a01
commit 32e6fe9286
2 changed files with 62 additions and 54 deletions

View File

@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
@@ -161,13 +161,30 @@ def train(
validate_config(cfg)
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(
cfg.base_model_config,
cfg.tokenizer_type,
cfg
)
if "inference" not in kwargs and "shard" not in kwargs: # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
logging.info("loading model, tokenizer, and peft_config...")
model, tokenizer, peft_config = load_model(
logging.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
cfg.tokenizer_type,
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
@@ -192,10 +209,6 @@ def train(
model.save_pretrained(cfg.output_dir)
return
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.debug:
logging.info("check_dataset_labels...")
check_dataset_labels(
@@ -205,10 +218,6 @@ def train(
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
import bitsandbytes as bnb
import torch
import transformers
from torch import nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
def load_tokenizer(
base_model_config,
tokenizer_type,
cfg,
):
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
return tokenizer
def load_model(
base_model,
base_model_config,
model_type,
tokenizer_type,
tokenizer,
cfg,
adapter="lora",
inference=False,
):
# type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
# type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
@@ -122,7 +157,7 @@ def load_model(
model_path = str(cache_model_path)
except:
model_path = cfg.base_model
model, tokenizer = load_llama_model_4bit_low_ram(
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
**model_kwargs,
)
if not tokenizer:
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)
@@ -291,7 +290,7 @@ def load_model(
model.config.use_cache = False
# TODO resume_from_checkpoint handling
return model, tokenizer, lora_config
return model, lora_config
def load_adapter(model, cfg, adapter):