diff --git a/scripts/finetune.py b/scripts/finetune.py
index 0c8727401..3d72fb1d9 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -21,7 +21,7 @@
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
 from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
@@ -161,13 +161,30 @@ def train(
 
     validate_config(cfg)
 
+    # load the tokenizer first
+    logging.info("loading tokenizer...")
+    tokenizer = load_tokenizer(
+        cfg.base_model_config,
+        cfg.tokenizer_type,
+        cfg
+    )
+
+    if "inference" not in kwargs and "shard" not in kwargs:  # don't need to load dataset for these
+        train_dataset, eval_dataset = load_prepare_datasets(
+            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+        )
+
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and peft_config...")
-    model, tokenizer, peft_config = load_model(
+    logging.info("loading model and peft_config...")
+    model, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
-        cfg.tokenizer_type,
+        tokenizer,
         cfg,
         adapter=cfg.adapter,
         inference=("inference" in kwargs),
@@ -192,10 +209,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-    train_dataset, eval_dataset = load_prepare_datasets(
-        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-    )
-
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
@@ -205,10 +218,6 @@ def train(
             tokenizer,
         )
 
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
     model.config.use_cache = False
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 939a312d5..6721537c2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import bitsandbytes as bnb
 import torch
 import transformers
-from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
+def load_tokenizer(
+    base_model_config,
+    tokenizer_type,
+    cfg,
+):
+    if tokenizer_type:
+        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
+            base_model_config,
+            trust_remote_code=True if cfg.trust_remote_code is True else False,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_config,
+            trust_remote_code=True if cfg.trust_remote_code is True else False,
+        )
+
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: v})
+    if cfg.tokens:
+        tokenizer.add_tokens(list(cfg.tokens))
+
+    return tokenizer
+
+
 def load_model(
     base_model,
     base_model_config,
     model_type,
-    tokenizer_type,
+    tokenizer,
     cfg,
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
-    tokenizer = None
     is_llama_derived_model = "llama" in base_model or (
         cfg.model_type and "llama" in cfg.model_type.lower()
     )
@@ -122,7 +157,7 @@ def load_model(
             model_path = str(cache_model_path)
         except:
             model_path = cfg.base_model
-        model, tokenizer = load_llama_model_4bit_low_ram(
+        model, _ = load_llama_model_4bit_low_ram(
            base_model_config if base_model_config else base_model,
            model_path,
            device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
             **model_kwargs,
         )
 
-    if not tokenizer:
-        try:
-            if is_llama_derived_model and "LlamaTokenizer" in globals():
-                tokenizer = LlamaTokenizer.from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-            else:
-                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-        except:
-            tokenizer = AutoTokenizer.from_pretrained(
-                base_model_config,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
-            )
-
-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
-    if cfg.tokens:
-        tokenizer.add_tokens(list(cfg.tokens))
-
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
@@ -291,7 +290,7 @@ def load_model(
         model.config.use_cache = False
 
     # TODO resume_from_checkpoint handling
-    return model, tokenizer, lora_config
+    return model, lora_config
 
 
 def load_adapter(model, cfg, adapter):
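
For reference, a minimal usage sketch of the split-out API after this change. It simply mirrors the call sites in scripts/finetune.py above; the cfg object stands for the training config that script already loads, and the field values shown in comments are illustrative, not taken from a real config.

# Usage sketch (not part of the patch): the tokenizer is now loaded on its own,
# then passed into load_model(), which no longer returns a tokenizer.
from axolotl.utils.models import load_model, load_tokenizer

tokenizer = load_tokenizer(
    cfg.base_model_config,  # e.g. a Hugging Face repo id (illustrative)
    cfg.tokenizer_type,     # e.g. "LlamaTokenizer"; falsy falls back to AutoTokenizer
    cfg,
)

model, peft_config = load_model(
    cfg.base_model,
    cfg.base_model_config,
    cfg.model_type,
    tokenizer,              # the loaded tokenizer replaces the old tokenizer_type argument
    cfg,
    adapter=cfg.adapter,
)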