Merge pull request #67 from OpenAccess-AI-Collective/refactor-tokenizer-load
load the tokenizer separately from the model
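This refactor pulls tokenizer construction out of load_model and into a new load_tokenizer helper in axolotl.utils.models. The train() entrypoint can now build the tokenizer first, prepare datasets with it, and exit early on prepare_ds_only runs without ever touching model weights; load_model takes the ready tokenizer as an argument and returns only the model and PEFT config. A minimal sketch of the resulting call order (an editor's illustration assembled from the diff below, not code from the commit), assuming cfg is the parsed axolotl config:

    from axolotl.utils.models import load_model, load_tokenizer

    # tokenizer first: it only needs the base model config, not the weights
    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)

    # the model load reuses that tokenizer and no longer returns one
    model, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        tokenizer,
        cfg,
        adapter=cfg.adapter,
        inference=False,
    )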
@@ -5,7 +5,7 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List, Dict, Any, Union

 import fire
 import torch
@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)

 from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -117,6 +117,10 @@ def choose_config(path: Path):
     return chosen_file


+def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
+    return not any(el in list2 for el in list1)
+
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -161,13 +165,30 @@ def train(

     validate_config(cfg)

+    # load the tokenizer first
+    logging.info("loading tokenizer...")
+    tokenizer = load_tokenizer(
+        cfg.base_model_config,
+        cfg.tokenizer_type,
+        cfg
+    )
+
+    if check_not_in(["inference", "shard", "merge_lora"], kwargs):  # don't need to load dataset for these
+        train_dataset, eval_dataset = load_prepare_datasets(
+            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+        )
+
+        if prepare_ds_only:
+            logging.info("Finished preparing dataset. Exiting...")
+            return
+
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and peft_config...")
-    model, tokenizer, peft_config = load_model(
+    logging.info("loading model and peft_config...")
+    model, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
-        cfg.tokenizer_type,
+        tokenizer,
         cfg,
         adapter=cfg.adapter,
         inference=("inference" in kwargs),
@@ -192,10 +213,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return

-    train_dataset, eval_dataset = load_prepare_datasets(
-        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-    )
-
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
@@ -205,10 +222,6 @@ def train(
             tokenizer,
         )

-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

     model.config.use_cache = False
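The new check_not_in helper returns True only when none of the elements of its first list occur in the second argument; for a dict such as the CLI kwargs, membership tests the keys. A quick illustration (the kwargs values are made up):

    check_not_in(["inference", "shard", "merge_lora"], {"inference": True})  # False -> skip dataset prep
    check_not_in(["inference", "shard", "merge_lora"], {})                   # True  -> load datasets for training

The remaining hunks are in the axolotl.utils.models module, where load_tokenizer now lives: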
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import bitsandbytes as bnb
 import torch
 import transformers
-from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer


+def load_tokenizer(
+    base_model_config,
+    tokenizer_type,
+    cfg,
+):
+    if tokenizer_type:
+        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
+            base_model_config,
+            trust_remote_code=cfg.trust_remote_code or False,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_config,
+            trust_remote_code=cfg.trust_remote_code or False,
+        )
+
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: v})
+    if cfg.tokens:
+        tokenizer.add_tokens(list(cfg.tokens))
+
+    return tokenizer
+
+
 def load_model(
     base_model,
     base_model_config,
     model_type,
-    tokenizer_type,
+    tokenizer,
     cfg,
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    tokenizer = None
     is_llama_derived_model = "llama" in base_model or (
         cfg.model_type and "llama" in cfg.model_type.lower()
     )
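Note the trust_remote_code plumbing is also tidied: cfg.trust_remote_code or False replaces the old True if cfg.trust_remote_code is True else False, and the tokenizer class is resolved dynamically from the tokenizer_type string. A sketch of that dispatch (the class name is just an example; any tokenizer class exported by the transformers module works):

    import transformers

    # tokenizer_type names a class on the transformers module...
    cls = getattr(transformers, "LlamaTokenizer")  # -> transformers.LlamaTokenizer
    # ...and when tokenizer_type is unset, AutoTokenizer picks the class
    # from the model's config instead.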
@@ -122,7 +157,7 @@ def load_model(
             model_path = str(cache_model_path)
         except:
             model_path = cfg.base_model
-        model, tokenizer = load_llama_model_4bit_low_ram(
+        model, _ = load_llama_model_4bit_low_ram(
            base_model_config if base_model_config else base_model,
            model_path,
            device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
             **model_kwargs,
         )

-    if not tokenizer:
-        try:
-            if is_llama_derived_model and "LlamaTokenizer" in globals():
-                tokenizer = LlamaTokenizer.from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-            else:
-                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-        except:
-            tokenizer = AutoTokenizer.from_pretrained(
-                base_model_config,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
-            )
-
-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
-    if cfg.tokens:
-        tokenizer.add_tokens(list(cfg.tokens))
-
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)

@@ -291,7 +290,7 @@ def load_model(
         model.config.use_cache = False

     # TODO resume_from_checkpoint handling
-    return model, tokenizer, lora_config
+    return model, lora_config


 def load_adapter(model, cfg, adapter):
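One tokenizer-adjacent step deliberately stays in load_model: resizing the embeddings, since that mutates the model rather than the tokenizer. A worked example of the rounding it applies (the token count is illustrative):

    import math

    # e.g. a 32,000-entry LLaMA vocab plus a pad token and two added tokens
    vocab_size = 32_003
    math.ceil(vocab_size / 32) * 32  # == 32032, padded up to a multiple of 32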