From 125cccb7864219c26b13a45966f46b9c16e1f1ff Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 29 Aug 2023 05:37:53 -0700
Subject: [PATCH] Refactor train cfg cli (#499)

* wip to cleanup cfg cli options

* fix launcher

* fix cli args
---
 scripts/finetune.py         | 126 +++++++++++++++++++++++-------------
 src/axolotl/utils/models.py |  40 ++++++------
 2 files changed, 101 insertions(+), 65 deletions(-)
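Note on the refactor below: CLI handling is split in two. Arguments that mirror YAML config keys are still consumed as fire-style --key value overrides, while run-mode switches (debug, inference, merge_lora, shard, ...) move into a new TrainerCliArgs dataclass parsed by transformers.HfArgumentParser. A minimal sketch of that parsing behavior, assuming only that transformers is installed; the argv values are illustrative:

    from dataclasses import dataclass, field
    from typing import Optional

    import transformers


    @dataclass
    class TrainerCliArgs:
        # trimmed mirror of the dataclass added in scripts/finetune.py below
        debug: bool = field(default=False)
        inference: bool = field(default=False)
        prompter: Optional[str] = field(default=None)


    parser = transformers.HfArgumentParser(TrainerCliArgs)
    # return_remaining_strings=True hands back unmatched argv entries (the
    # config path and any yaml-override pairs) instead of raising on them.
    cli_args, remaining = parser.parse_args_into_dataclasses(
        args=["config.yml", "--inference", "--prompter", "None"],
        return_remaining_strings=True,
    )
    assert cli_args.inference is True
    assert remaining == ["config.yml"]

Because argparse delivers --prompter None as the literal string "None", the do_inference hunk below normalizes it back to a real None.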
{cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) - if ( - check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference + if not ( + cli_args.shard or cli_args.merge_lora or cli_args.inference ): # don't need to load dataset for these train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) - if cfg.debug or "debug" in kwargs: + if cli_args.debug or cfg.debug: LOG.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( @@ -205,17 +198,17 @@ def train( tokenizer, ) - if prepare_ds_only: + if cli_args.prepare_ds_only: LOG.info("Finished preparing dataset. Exiting...") return # Load the model and tokenizer LOG.info("loading model and (optionally) peft_config...") - model, peft_config = load_model(cfg, tokenizer) + model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) safe_serialization = cfg.save_safetensors is True - if "merge_lora" in kwargs and cfg.adapter is not None: + if cli_args.merge_lora and cfg.adapter is not None: LOG.info("running merge of LoRA with base model") model = model.merge_and_unload() model.to(dtype=torch.float16) @@ -229,18 +222,13 @@ def train( tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) return - if cfg.inference: - LOG.info("calling do_inference function") - prompter: Optional[str] = "AlpacaPrompter" - if "prompter" in kwargs: - if kwargs["prompter"] == "None": - prompter = None - else: - prompter = kwargs["prompter"] - do_inference(cfg, model, tokenizer, prompter=prompter) + if cli_args.inference: + LOG.debug("Running inference on model") + do_inference(cfg, model, tokenizer, prompter=cli_args.prompter) return - if "shard" in kwargs: + if cli_args.shard: + LOG.debug("Re-saving model w/ sharding") model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) return @@ -322,5 +310,51 @@ def train( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) +def load_cfg(config: Path = Path("examples/"), **kwargs): + if Path(config).is_dir(): + config = choose_config(config) + + # load the config from the yaml file + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + # if there are any options passed in the cli, if it is something that seems valid from the yaml, + # then overwrite the value + cfg_keys = cfg.keys() + for k, _ in kwargs.items(): + # if not strict, allow writing to cfg even if it's not in the yml already + if k in cfg_keys or not cfg.strict: + # handle booleans + if isinstance(cfg[k], bool): + cfg[k] = bool(kwargs[k]) + else: + cfg[k] = kwargs[k] + + model_config = load_model_config(cfg) + + # figure out if the model is llama + cfg.is_llama_derived_model = ( + (hasattr(model_config, "model_type") and model_config.model_type == "llama") + or cfg.is_llama_derived_model + or "llama" in cfg.base_model + or (cfg.model_type and "llama" in cfg.model_type.lower()) + ) + validate_config(cfg) + + normalize_config(cfg) + + setup_wandb_env_vars(cfg) + return cfg + + +def do_train(config: Path = Path("examples/"), **kwargs): + print_axolotl_text_art() + parsed_cfg = load_cfg(config, **kwargs) + parser = transformers.HfArgumentParser((TrainerCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + train(cfg=parsed_cfg, cli_args=parsed_cli_args) + + if __name__ == "__main__": - fire.Fire(train) + fire.Fire(do_train) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d0e5128ef..4b9c79d84 
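Note between the two files: the override loop that moved from train() into load_cfg() keeps fire's permissive behavior. Any --key value pair is written into the config when the key already exists in the YAML or when strict is falsy, and the value is coerced with bool() when the YAML default was a boolean. A standalone sketch of those semantics, using a plain dict where the real code uses DictDefault (which returns None for missing keys):

    # Plain-dict sketch of the load_cfg override loop above.
    cfg = {"load_in_8bit": False, "learning_rate": 0.0002, "strict": False}
    overrides = {"load_in_8bit": 1, "micro_batch_size": 4}

    for key, value in overrides.items():
        if key in cfg or not cfg.get("strict"):
            # handle booleans: a truthy CLI value becomes True, not the raw value
            cfg[key] = bool(value) if isinstance(cfg.get(key), bool) else value

    assert cfg["load_in_8bit"] is True
    assert cfg["micro_batch_size"] == 4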
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d0e5128ef..4b9c79d84 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -5,12 +5,13 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
+from typing import Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft import PeftConfig
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -23,13 +24,17 @@ from transformers import (  # noqa: F401
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")
 
-if TYPE_CHECKING:
-    from peft import PeftConfig  # noqa: F401
-
-    from axolotl.utils.dict import DictDefault  # noqa: F401
+
+def load_model_config(cfg):
+    model_config_name = cfg.base_model_config or cfg.base_model
+    trust_remote_code: bool = False or cfg.trust_remote_code
+    return AutoConfig.from_pretrained(
+        model_config_name, trust_remote_code=trust_remote_code
+    )
 
 
 def load_tokenizer(cfg):
@@ -86,8 +91,10 @@ def load_tokenizer(cfg):
 
 
 def load_model(
-    cfg, tokenizer
-):  # type: (DictDefault, PreTrainedTokenizerBase) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    cfg: DictDefault,
+    tokenizer: PreTrainedTokenizerBase,
+    inference: bool = False,
+) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
     """
     Load a model for a given configuration and tokenizer.
     """
@@ -97,14 +104,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    cfg.is_llama_derived_model = (
-        "llama" in base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-        or cfg.is_llama_derived_model
-    )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
+        if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
 
@@ -146,7 +148,7 @@ def load_model(
     if (
         cfg.is_llama_derived_model
         and (cfg.max_packed_sequence_len or cfg.sample_packing)
-        and not cfg.inference
+        and not inference
     ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -424,15 +426,15 @@ def load_model(
     return model, lora_config
 
 
-def load_adapter(model, cfg, adapter):
-    # type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_adapter(model, cfg, adapter, inference=False):
+    # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
     if adapter in ["lora", "qlora"]:
-        return load_lora(model, cfg)
+        return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
 
@@ -478,8 +480,8 @@ def find_all_linear_names(model):
     return list(lora_module_names)
 
 
-def load_lora(model, cfg):
-    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_lora(model, cfg, inference=False):
+    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     from peft import LoraConfig, PeftModel, get_peft_model
 
@@ -506,7 +508,7 @@ def load_lora(model, cfg):
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
-            is_trainable=not cfg.inference,
+            is_trainable=(not inference),
         )
     else:
         model = get_peft_model(model, lora_config)
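Closing note: with load_model_config resolving the HF AutoConfig up front, load_cfg can decide is_llama_derived_model once at config-load time, and load_model no longer mutates cfg or reads cfg.inference. A hypothetical usage sketch; the checkpoint name is illustrative and fetching its config requires network access:

    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_model_config

    # any llama-family checkpoint would behave the same way here
    cfg = DictDefault({"base_model": "openlm-research/open_llama_3b"})

    model_config = load_model_config(cfg)
    # load_cfg combines model_config.model_type, the model name, and
    # cfg.model_type into the single cfg.is_llama_derived_model flag.
    print(model_config.model_type)  # -> "llama"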