Merge pull request #67 from OpenAccess-AI-Collective/refactor-tokenizer-load

load the tokenizer seperately from the model
This commit is contained in:
Wing Lian
2023-05-28 08:49:34 -04:00
committed by GitHub
2 changed files with 67 additions and 55 deletions

View File

@@ -5,7 +5,7 @@ import random
import signal
import sys
from pathlib import Path
from typing import Optional
from typing import Optional, List, Dict, Any, Union
import fire
import torch
@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
@@ -117,6 +117,10 @@ def choose_config(path: Path):
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
@@ -161,13 +165,30 @@ def train(
validate_config(cfg)
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(
cfg.base_model_config,
cfg.tokenizer_type,
cfg
)
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
logging.info("loading model, tokenizer, and peft_config...")
model, tokenizer, peft_config = load_model(
logging.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
cfg.tokenizer_type,
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
@@ -192,10 +213,6 @@ def train(
model.save_pretrained(cfg.output_dir)
return
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.debug:
logging.info("check_dataset_labels...")
check_dataset_labels(
@@ -205,10 +222,6 @@ def train(
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
import bitsandbytes as bnb
import torch
import transformers
from torch import nn
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
def load_tokenizer(
base_model_config,
tokenizer_type,
cfg,
):
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=cfg.trust_remote_code or False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=cfg.trust_remote_code or False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
return tokenizer
def load_model(
base_model,
base_model_config,
model_type,
tokenizer_type,
tokenizer,
cfg,
adapter="lora",
inference=False,
):
# type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
# type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
)
@@ -122,7 +157,7 @@ def load_model(
model_path = str(cache_model_path)
except:
model_path = cfg.base_model
model, tokenizer = load_llama_model_4bit_low_ram(
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
model_path,
device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
**model_kwargs,
)
if not tokenizer:
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)
@@ -291,7 +290,7 @@ def load_model(
model.config.use_cache = False
# TODO resume_from_checkpoint handling
return model, tokenizer, lora_config
return model, lora_config
def load_adapter(model, cfg, adapter):