Merge pull request #67 from OpenAccess-AI-Collective/refactor-tokenizer-load

load the tokenizer seperately from the model
This commit is contained in:
Wing Lian
2023-05-28 08:49:34 -04:00
committed by GitHub
2 changed files with 67 additions and 55 deletions

View File

@@ -5,7 +5,7 @@ import random
import signal import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, List, Dict, Any, Union
import fire import fire
import torch import torch
@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
@@ -117,6 +117,10 @@ def choose_config(path: Path):
return chosen_file return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def train( def train(
config: Path = Path("configs/"), config: Path = Path("configs/"),
prepare_ds_only: bool = False, prepare_ds_only: bool = False,
@@ -161,13 +165,30 @@ def train(
validate_config(cfg) validate_config(cfg)
# load the tokenizer first
logging.info("loading tokenizer...")
tokenizer = load_tokenizer(
cfg.base_model_config,
cfg.tokenizer_type,
cfg
)
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer # Load the model and tokenizer
logging.info("loading model, tokenizer, and peft_config...") logging.info("loading model and peft_config...")
model, tokenizer, peft_config = load_model( model, peft_config = load_model(
cfg.base_model, cfg.base_model,
cfg.base_model_config, cfg.base_model_config,
cfg.model_type, cfg.model_type,
cfg.tokenizer_type, tokenizer,
cfg, cfg,
adapter=cfg.adapter, adapter=cfg.adapter,
inference=("inference" in kwargs), inference=("inference" in kwargs),
@@ -192,10 +213,6 @@ def train(
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
return return
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.debug: if cfg.debug:
logging.info("check_dataset_labels...") logging.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
@@ -205,10 +222,6 @@ def train(
tokenizer, tokenizer,
) )
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False model.config.use_cache = False

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from torch import nn
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
def load_tokenizer(
base_model_config,
tokenizer_type,
cfg,
):
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=cfg.trust_remote_code or False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=cfg.trust_remote_code or False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
return tokenizer
def load_model( def load_model(
base_model, base_model,
base_model_config, base_model_config,
model_type, model_type,
tokenizer_type, tokenizer,
cfg, cfg,
adapter="lora", adapter="lora",
inference=False, inference=False,
): ):
# type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
tokenizer = None
is_llama_derived_model = "llama" in base_model or ( is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower() cfg.model_type and "llama" in cfg.model_type.lower()
) )
@@ -122,7 +157,7 @@ def load_model(
model_path = str(cache_model_path) model_path = str(cache_model_path)
except: except:
model_path = cfg.base_model model_path = cfg.base_model
model, tokenizer = load_llama_model_4bit_low_ram( model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model, base_model_config if base_model_config else base_model,
model_path, model_path,
device_map=cfg.device_map, device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
if not tokenizer:
try:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
except:
tokenizer = AutoTokenizer.from_pretrained(
base_model_config,
trust_remote_code=True if cfg.trust_remote_code is True else False,
)
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
for k, v in cfg.special_tokens.items():
tokenizer.add_special_tokens({k: v})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
embeddings_len = math.ceil(len(tokenizer) / 32) * 32 embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len) model.resize_token_embeddings(embeddings_len)
@@ -291,7 +290,7 @@ def load_model(
model.config.use_cache = False model.config.use_cache = False
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, tokenizer, lora_config return model, lora_config
def load_adapter(model, cfg, adapter): def load_adapter(model, cfg, adapter):