From 861cecac2abb91b728aea20df7c47f23d6b2d5b3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Sep 2023 01:43:52 -0400
Subject: [PATCH] refactor scripts/finetune.py into new cli modules (#550)

* refactor scripts/finetune.py into new cli modules

* continue to support scripts/finetune.py

* update readme with updated cli commands

* Update scripts/finetune.py

Co-authored-by: NanoCode012

---------

Co-authored-by: NanoCode012
---
 README.md                     |  22 +--
 scripts/finetune.py           | 273 +++-------------------------------
 src/axolotl/cli/__init__.py   | 249 +++++++++++++++++++++++++++++++
 src/axolotl/cli/inference.py  |  26 ++++
 src/axolotl/cli/merge_lora.py |  26 ++++
 src/axolotl/cli/shard.py      |  41 +++++
 src/axolotl/cli/train.py      |  35 +++++
 7 files changed, 407 insertions(+), 265 deletions(-)
 create mode 100644 src/axolotl/cli/__init__.py
 create mode 100644 src/axolotl/cli/inference.py
 create mode 100644 src/axolotl/cli/merge_lora.py
 create mode 100644 src/axolotl/cli/shard.py
 create mode 100644 src/axolotl/cli/train.py

diff --git a/README.md b/README.md
index 0ba61eb58..4e213018d 100644
--- a/README.md
+++ b/README.md
@@ -76,11 +76,11 @@ pip3 install -e .[flash-attn]
 pip3 install -U git+https://github.com/huggingface/peft.git
 
 # finetune lora
-accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
+accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 
 # inference
-accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
-    --inference --lora_model_dir="./lora-out"
+accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
+    --lora_model_dir="./lora-out"
 ```
 
 ## Installation
@@ -674,14 +674,14 @@ strict:
 
 Run
 ```bash
-accelerate launch scripts/finetune.py your_config.yml
+accelerate launch -m axolotl.cli.train your_config.yml
 ```
 
 #### Multi-GPU
 
 You can optionally pre-tokenize dataset with the following before finetuning:
 ```bash
-CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
+CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
 ```
 
 ##### Config
@@ -720,16 +720,16 @@ Pass the appropriate flag to the train command:
 
 - Pretrained LORA:
   ```bash
-  --inference --lora_model_dir="./lora-output-dir"
+  python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
   ```
 - Full weights finetune:
   ```bash
-  --inference --base_model="./completed-model"
+  python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
   ```
 - Full weights finetune w/ a prompt from a text file:
   ```bash
-  cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
-    --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
+  cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
+    --base_model="./completed-model" --prompter=None --load_in_8bit=True
   ```
 
 ### Merge LORA to base
 
 Add below flag to train command above
 
 ```bash
---merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
 If you run out of CUDA memory, you can try to merge in system RAM with
 
 ```bash
-CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
+CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
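+# For example (illustrative paths, reusing the flags from the merge command above):
+# CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora examples/your_config.yml \
+#     --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False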
 ```
 
 ## Common Errors 🧰
 
diff --git a/scripts/finetune.py b/scripts/finetune.py
index c149ad073..7b6751e31 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -1,269 +1,34 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
-
-import importlib
 import logging
-import os
-import random
-import sys
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
 
 import fire
-import torch
 import transformers
-import yaml
 
-# add src to the pythonpath so we don't need to pip install this
-from accelerate.commands.config import config_args
-from art import text2art
-from transformers import GenerationConfig, TextStreamer
+from axolotl.cli import (
+    check_accelerate_default_config,
+    do_inference,
+    do_merge_lora,
+    load_cfg,
+    load_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.cli.shard import shard
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
 
-from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
-from axolotl.logging_config import configure_logging
-from axolotl.train import TrainDatasetMeta, train
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.data import prepare_dataset
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_tokenizer
-from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.wandb_ import setup_wandb_env_vars
-
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
-
-configure_logging()
-LOG = logging.getLogger("axolotl.scripts")
-
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-
-def print_axolotl_text_art(suffix=None):
-    font = "nancyj"
-    ascii_text = " axolotl"
-    if suffix:
-        ascii_text += f" x {suffix}"
-    ascii_art = text2art(" axolotl", font=font)
-
-    if is_main_process():
-        print(ascii_art)
-
-
-def get_multi_line_input() -> Optional[str]:
-    print("Give me an instruction (Ctrl + D to finish): ")
-    instruction = ""
-    for line in sys.stdin:
-        instruction += line  # pylint: disable=consider-using-join
-    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
-    return instruction
-
-
-def do_merge_lora(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    safe_serialization = cfg.save_safetensors is True
-
-    LOG.info("running merge of LoRA with base model")
-    model = model.merge_and_unload()
-    model.to(dtype=torch.float16)
-
-    if cfg.local_rank == 0:
-        LOG.info("saving merged model")
-        model.save_pretrained(
-            str(Path(cfg.output_dir) / "merged"),
-            safe_serialization=safe_serialization,
-        )
-        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
-
-
-def shard(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    safe_serialization = cfg.save_safetensors is True
-    LOG.debug("Re-saving model w/ sharding")
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
-
-def do_inference(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    prompter = cli_args.prompter
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
-
-    for token, symbol in default_tokens.items():
-        # If the token isn't already specified in the config, add it
-        if not (cfg.special_tokens and token in cfg.special_tokens):
-            tokenizer.add_special_tokens({token: symbol})
-
-    prompter_module = None
-    if prompter:
-        prompter_module = getattr(
-            importlib.import_module("axolotl.prompters"), prompter
-        )
-
-    if cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
-
-        set_model_mem_id(model, tokenizer)
-        model.set_mem_cache_args(
-            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
-        )
-
-    model = model.to(cfg.device)
-
-    while True:
-        print("=" * 80)
-        # support for multiline inputs
-        instruction = get_multi_line_input()
-        if not instruction:
-            return
-        if prompter_module:
-            prompt: str = next(
-                prompter_module().build_prompt(instruction=instruction.strip("\n"))
-            )
-        else:
-            prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
-        print("=" * 40)
-        model.eval()
-        with torch.no_grad():
-            generation_config = GenerationConfig(
-                repetition_penalty=1.1,
-                max_new_tokens=1024,
-                temperature=0.9,
-                top_p=0.95,
-                top_k=40,
-                bos_token_id=tokenizer.bos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                pad_token_id=tokenizer.pad_token_id,
-                do_sample=True,
-                use_cache=True,
-                return_dict_in_generate=True,
-                output_attentions=False,
-                output_hidden_states=False,
-                output_scores=False,
-            )
-            streamer = TextStreamer(tokenizer)
-            generated = model.generate(
-                inputs=batch["input_ids"].to(cfg.device),
-                generation_config=generation_config,
-                streamer=streamer,
-            )
-        print("=" * 40)
-        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
-
-
-def choose_config(path: Path):
-    yaml_files = list(path.glob("*.yml"))
-
-    if not yaml_files:
-        raise ValueError(
-            "No YAML config files found in the specified directory. Are you using a .yml extension?"
-        )
-
-    if len(yaml_files) == 1:
-        print(f"Using default YAML file '{yaml_files[0]}'")
-        return yaml_files[0]
-
-    print("Choose a YAML file:")
-    for idx, file in enumerate(yaml_files):
-        print(f"{idx + 1}. {file}")
-
-    chosen_file = None
-    while chosen_file is None:
-        try:
-            choice = int(input("Enter the number of your choice: "))
-            if 1 <= choice <= len(yaml_files):
-                chosen_file = yaml_files[choice - 1]
-            else:
-                print("Invalid choice. Please choose a number from the list.")
-        except ValueError:
-            print("Invalid input. Please enter a number.")
-
-    return chosen_file
-
-
-def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
-    return not any(el in list2 for el in list1)
-
-
-def load_cfg(config: Path = Path("examples/"), **kwargs):
-    if Path(config).is_dir():
-        config = choose_config(config)
-
-    # load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
-    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
-    # then overwrite the value
-    cfg_keys = cfg.keys()
-    for k, _ in kwargs.items():
-        # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or not cfg.strict:
-            # handle booleans
-            if isinstance(cfg[k], bool):
-                cfg[k] = bool(kwargs[k])
-            else:
-                cfg[k] = kwargs[k]
-
-    validate_config(cfg)
-
-    normalize_config(cfg)
-
-    setup_wandb_env_vars(cfg)
-    return cfg
-
-
-def load_datasets(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-) -> TrainDatasetMeta:
-    tokenizer = load_tokenizer(cfg)
-
-    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
-
-    if cli_args.debug or cfg.debug:
-        LOG.info("check_dataset_labels...")
-        check_dataset_labels(
-            train_dataset.select(
-                [
-                    random.randrange(0, len(train_dataset) - 1)  # nosec
-                    for _ in range(cli_args.debug_num_examples)
-                ]
-            ),
-            tokenizer,
-            num_examples=cli_args.debug_num_examples,
-            text_only=cli_args.debug_text_only,
-        )
-
-    return TrainDatasetMeta(
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        total_num_steps=total_num_steps,
-    )
-
-
-def check_accelerate_default_config():
-    if Path(config_args.default_yaml_config_file).exists():
-        LOG.warning(
-            f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
-        )
+LOG = logging.getLogger("axolotl.scripts.finetune")
 
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
     print_axolotl_text_art()
+    LOG.warning(
+        str(
+            PendingDeprecationWarning(
+                "scripts/finetune.py will be replaced with calling axolotl.cli.train"
+            )
+        )
+    )
     parsed_cfg = load_cfg(config, **kwargs)
     check_accelerate_default_config()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
new file mode 100644
index 000000000..ff8eb3b91
--- /dev/null
+++ b/src/axolotl/cli/__init__.py
@@ -0,0 +1,249 @@
+"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+
+import importlib
+import logging
+import os
+import random
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+import yaml
+
+# add src to the pythonpath so we don't need to pip install this
+from accelerate.commands.config import config_args
+from art import text2art
+from transformers import GenerationConfig, TextStreamer
+
+from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.logging_config import configure_logging
+from axolotl.train import TrainDatasetMeta
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.data import prepare_dataset
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import is_main_process
+from axolotl.utils.models import load_tokenizer
+from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.wandb_ import setup_wandb_env_vars
+
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)
+
+configure_logging()
+LOG = logging.getLogger("axolotl.scripts")
+
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+
+def print_axolotl_text_art(suffix=None):
+    font = "nancyj"
+    ascii_text = " axolotl"
+    if suffix:
+        ascii_text += f" x {suffix}"
+    ascii_art = text2art(ascii_text, font=font)  # use the suffixed text so the banner reflects it
+
+    if is_main_process():
+        print(ascii_art)
+
+
+def get_multi_line_input() -> Optional[str]:
+    print("Give me an instruction (Ctrl + D to finish): ")
+    instruction = ""
+    for line in sys.stdin:
+        instruction += line  # pylint: disable=consider-using-join
+    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
+    return instruction
+
+
+def do_merge_lora(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    safe_serialization = cfg.save_safetensors is True
+
+    LOG.info("running merge of LoRA with base model")
+    model = model.merge_and_unload()
+    model.to(dtype=torch.float16)
+
+    if cfg.local_rank == 0:
+        LOG.info("saving merged model")
+        model.save_pretrained(
+            str(Path(cfg.output_dir) / "merged"),
+            safe_serialization=safe_serialization,
+        )
+        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
+
+def do_inference(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    prompter = cli_args.prompter
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
+
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
+
+    if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
+    model = model.to(cfg.device)
+
+    while True:
+        print("=" * 80)
+        # support for multiline inputs
+        instruction = get_multi_line_input()
+        if not instruction:
+            return
+        if prompter_module:
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        print("=" * 40)
+        model.eval()
+        with torch.no_grad():
+            generation_config = GenerationConfig(
+                repetition_penalty=1.1,
+                max_new_tokens=1024,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=True,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+            streamer = TextStreamer(tokenizer)
+            generated = model.generate(
+                inputs=batch["input_ids"].to(cfg.device),
+                generation_config=generation_config,
+                streamer=streamer,
+            )
+        print("=" * 40)
+        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
+
+
+def choose_config(path: Path):
+    yaml_files = list(path.glob("*.yml"))
+
+    if not yaml_files:
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
+
+    if len(yaml_files) == 1:
+        print(f"Using default YAML file '{yaml_files[0]}'")
+        return yaml_files[0]
+
+    print("Choose a YAML file:")
+    for idx, file in enumerate(yaml_files):
+        print(f"{idx + 1}. {file}")
+
+    chosen_file = None
+    while chosen_file is None:
+        try:
+            choice = int(input("Enter the number of your choice: "))
+            if 1 <= choice <= len(yaml_files):
+                chosen_file = yaml_files[choice - 1]
+            else:
+                print("Invalid choice. Please choose a number from the list.")
+        except ValueError:
+            print("Invalid input. Please enter a number.")
+
+    return chosen_file
+
+
+def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
+    return not any(el in list2 for el in list1)
+
+
+def load_cfg(config: Path = Path("examples/"), **kwargs):
+    if Path(config).is_dir():
+        config = choose_config(config)
+
+    # load the config from the yaml file
+    with open(config, encoding="utf-8") as file:
+        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
+    # then overwrite the value
+    cfg_keys = cfg.keys()
+    for k, _ in kwargs.items():
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or not cfg.strict:
+            # handle booleans
+            if isinstance(cfg[k], bool):
+                cfg[k] = bool(kwargs[k])
+            else:
+                cfg[k] = kwargs[k]
+
+    validate_config(cfg)
+
+    normalize_config(cfg)
+
+    setup_wandb_env_vars(cfg)
+    return cfg
+
+
+def load_datasets(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+) -> TrainDatasetMeta:
+    tokenizer = load_tokenizer(cfg)
+
+    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
+
+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+        check_dataset_labels(
+            train_dataset.select(
+                [
+                    random.randrange(0, len(train_dataset) - 1)  # nosec
+                    for _ in range(cli_args.debug_num_examples)
+                ]
+            ),
+            tokenizer,
+            num_examples=cli_args.debug_num_examples,
+            text_only=cli_args.debug_text_only,
+        )
+
+    return TrainDatasetMeta(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        total_num_steps=total_num_steps,
+    )
+
+
+def check_accelerate_default_config():
+    if Path(config_args.default_yaml_config_file).exists():
+        LOG.warning(
+            f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
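One behavior of `load_cfg` above worth calling out: any extra `--key=value` argument from the CLI is folded over the YAML config (subject to `strict`), so one-off overrides need no file edits. A minimal sketch, assuming `output_dir` is a key in your YAML:

```bash
python -m axolotl.cli.train examples/openllama-3b/lora.yml --output_dir="./custom-run-out"
```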
+        )
diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
new file mode 100644
index 000000000..1a5a1a2be
--- /dev/null
+++ b/src/axolotl/cli/inference.py
@@ -0,0 +1,26 @@
+"""
+CLI to run inference on a trained model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.inference = True
+
+    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
new file mode 100644
index 000000000..473aa8260
--- /dev/null
+++ b/src/axolotl/cli/merge_lora.py
@@ -0,0 +1,26 @@
+"""
+CLI to merge a trained LoRA into a base model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.merge_lora = True
+
+    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py
new file mode 100644
index 000000000..ad7d9a136
--- /dev/null
+++ b/src/axolotl/cli/shard.py
@@ -0,0 +1,41 @@
+"""
+CLI to shard a trained model into 10GiB chunks
+"""
+import logging
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.scripts")
+
+
+def shard(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    safe_serialization = cfg.save_safetensors is True
+    LOG.debug("Re-saving model w/ sharding")
+    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.shard = True
+
+    shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
new file mode 100644
index 000000000..166af2595
--- /dev/null
+++ b/src/axolotl/cli/train.py
@@ -0,0 +1,35 @@
+"""
+CLI to run training on a model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import (
+    check_accelerate_default_config,
+    load_cfg,
+    load_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    check_accelerate_default_config()
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+
+    dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cli_args.prepare_ds_only:
+        return
+    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+
+
+fire.Fire(do_cli)