diff --git a/scripts/finetune.py b/scripts/finetune.py
index 8019af8e3..201a47e14 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -4,9 +4,7 @@
 import importlib
 import logging
 import os
 import random
-import signal
 import sys
-from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -17,17 +15,17 @@
 import yaml
 # add src to the pythonpath so we don't need to pip install this
 from art import text2art
-from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer

+from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
+from axolotl.train import TrainDatasetMeta, train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model, load_model_config, load_tokenizer
+from axolotl.utils.models import load_model_config, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars

 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -40,26 +38,13 @@
 LOG = logging.getLogger("axolotl.scripts")

 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

-@dataclass
-class TrainerCliArgs:
-    """
-    dataclass representing the various non-training arguments
-    """
-
-    debug: bool = field(default=False)
-    inference: bool = field(default=False)
-    merge_lora: bool = field(default=False)
-    prepare_ds_only: bool = field(default=False)
-    prompter: Optional[str] = field(default=None)
-    shard: bool = field(default=False)
-
-
 def print_axolotl_text_art(suffix=None):
     font = "nancyj"
     ascii_text = " axolotl"
     if suffix:
         ascii_text += f" x {suffix}"
     ascii_art = text2art(" axolotl", font=font)
+
     if is_main_process():
         print(ascii_art)
@@ -73,9 +58,45 @@ def get_multi_line_input() -> Optional[str]:
     return instruction


-def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
-    if prompter == "None":
-        prompter = None
+def do_merge_lora(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    safe_serialization = cfg.save_safetensors is True
+
+    LOG.info("running merge of LoRA with base model")
+    model = model.merge_and_unload()
+    model.to(dtype=torch.float16)
+
+    if cfg.local_rank == 0:
+        LOG.info("saving merged model")
+        model.save_pretrained(
+            str(Path(cfg.output_dir) / "merged"),
+            safe_serialization=safe_serialization,
+        )
+        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
+
+def shard(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    safe_serialization = cfg.save_safetensors is True
+    LOG.debug("Re-saving model w/ sharding")
+    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+
+def do_inference(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    prompter = cli_args.prompter

     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
     for token, symbol in default_tokens.items():
@@ -176,141 +197,6 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
     return not any(el in list2 for el in list1)


-def train(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    # load the tokenizer first
-    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
-    tokenizer = load_tokenizer(cfg)
-
-    if not (
-        cli_args.shard or cli_args.merge_lora or cli_args.inference
-    ):  # don't need to load dataset for these
-        train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
-
-        if cli_args.debug or cfg.debug:
-            LOG.info("check_dataset_labels...")
-            check_dataset_labels(
-                train_dataset.select(
-                    [random.randrange(0, len(train_dataset) - 1) for _ in range(5)]  # nosec
-                ),
-                tokenizer,
-            )
-
-        if cli_args.prepare_ds_only:
-            LOG.info("Finished preparing dataset. Exiting...")
-            return
-
-    # Load the model and tokenizer
-    LOG.info("loading model and (optionally) peft_config...")
-    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
-
-    safe_serialization = cfg.save_safetensors is True
-
-    if cli_args.merge_lora and cfg.adapter is not None:
-        LOG.info("running merge of LoRA with base model")
-        model = model.merge_and_unload()
-        model.to(dtype=torch.float16)
-
-        if cfg.local_rank == 0:
-            LOG.info("saving merged model")
-            model.save_pretrained(
-                str(Path(cfg.output_dir) / "merged"),
-                safe_serialization=safe_serialization,
-            )
-            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
-        return
-
-    if cli_args.inference:
-        LOG.debug("Running inference on model")
-        do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
-        return
-
-    if cli_args.shard:
-        LOG.debug("Re-saving model w/ sharding")
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-        return
-
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
-
-    trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
-    )
-
-    model.config.use_cache = False
-
-    if torch.__version__ >= "2" and sys.platform != "win32":
-        LOG.info("Compiling torch model")
-        model = torch.compile(model)
-
-    # go ahead and presave, so we have the adapter config available to inspect
-    if peft_config:
-        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        peft_config.save_pretrained(cfg.output_dir)
-
-    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-    if cfg.local_rank == 0:
-
-        def terminate_handler(_, __, model):
-            if cfg.flash_optimum:
-                model = BetterTransformer.reverse(model)
-            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-            sys.exit(0)
-
-        signal.signal(
-            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
-        )
-
-    LOG.info("Starting trainer...")
-    if cfg.group_by_length:
-        LOG.info("hang tight... sorting dataset for group_by_length")
-
-    if not Path(cfg.output_dir).is_dir():
-        os.makedirs(cfg.output_dir, exist_ok=True)
-    tokenizer.save_pretrained(cfg.output_dir)
-    if cfg.flash_optimum:
-        with torch.backends.cuda.sdp_kernel(
-            enable_flash=True, enable_math=True, enable_mem_efficient=True
-        ):
-            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-    else:
-        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
-
-    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
-
-    if cfg.relora_steps:
-        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
-            model = model.merge_and_unload()
-        else:
-            # final model weights have already been saved by `ReLoRACallback.on_train_end`
-            return
-
-    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
-    if cfg.fsdp:
-        trainer.save_model(cfg.output_dir)
-    elif cfg.local_rank == 0:
-        if cfg.flash_optimum:
-            model = BetterTransformer.reverse(model)
-
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
-
 def load_cfg(config: Path = Path("examples/"), **kwargs):
     if Path(config).is_dir():
         config = choose_config(config)
@@ -347,15 +233,50 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
     return cfg


-def do_train(config: Path = Path("examples/"), **kwargs):
+def load_datasets(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+) -> TrainDatasetMeta:
+    tokenizer = load_tokenizer(cfg)
+
+    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
+
+    if cli_args.debug or cfg.debug:
+        LOG.info("check_dataset_labels...")
+        check_dataset_labels(
+            train_dataset.select(
+                [random.randrange(0, len(train_dataset) - 1) for _ in range(5)]  # nosec
+            ),
+            tokenizer,
+        )
+
+    return TrainDatasetMeta(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        total_num_steps=total_num_steps,
+    )
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
-    train(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cli_args.inference:
+        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    elif parsed_cli_args.merge_lora:
+        do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    elif parsed_cli_args.shard:
+        shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+        if parsed_cli_args.prepare_ds_only:
+            return
+        train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


 if __name__ == "__main__":
-    fire.Fire(do_train)
+    fire.Fire(do_cli)
diff --git a/src/axolotl/common/__init__.py b/src/axolotl/common/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
new file mode 100644
index 000000000..f5bd9b037
--- /dev/null
+++ b/src/axolotl/common/cli.py
@@ -0,0 +1,41 @@
+"""
+shared module for cli specific things
+"""
+
+import logging
+from dataclasses import dataclass, field
+from typing import Optional
+
+from axolotl.logging_config import configure_logging
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
+configure_logging()
+LOG = logging.getLogger("axolotl.common.cli")
+
+
+@dataclass
+class TrainerCliArgs:
+    """
+    dataclass representing the various non-training arguments
+    """
+
+    debug: bool = field(default=False)
+    inference: bool = field(default=False)
+    merge_lora: bool = field(default=False)
+    prepare_ds_only: bool = field(default=False)
+    prompter: Optional[str] = field(default=None)
+    shard: bool = field(default=False)
+
+
+def load_model_and_tokenizer(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    tokenizer = load_tokenizer(cfg)
+    LOG.info("loading model and (optionally) peft_config...")
+    model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+    return model, tokenizer
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
new file mode 100644
index 000000000..51ef35903
--- /dev/null
+++ b/src/axolotl/train.py
@@ -0,0 +1,139 @@
+"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+
+import logging
+import os
+import signal
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+# add src to the pythonpath so we don't need to pip install this
+from datasets import Dataset
+from optimum.bettertransformer import BetterTransformer
+
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.logging_config import configure_logging
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+from axolotl.utils.trainer import setup_trainer
+
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)
+
+configure_logging()
+LOG = logging.getLogger("axolotl.train")
+
+
+@dataclass
+class TrainDatasetMeta:
+    """
+    dataclass to capture the dataset specific options for training
+    """
+
+    train_dataset: Dataset
+    eval_dataset: Optional[Dataset] = None
+    total_num_steps: Optional[int] = None
+
+
+def train(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+    dataset_meta: TrainDatasetMeta,
+):
+    # load the tokenizer first
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    tokenizer = load_tokenizer(cfg)
+
+    train_dataset = dataset_meta.train_dataset
+    eval_dataset = dataset_meta.eval_dataset
+    total_num_steps = dataset_meta.total_num_steps
+
+    # Load the model and tokenizer
+    LOG.info("loading model and (optionally) peft_config...")
+    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+    safe_serialization = cfg.save_safetensors is True
+
+    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
+        if len(possible_checkpoints) > 0:
+            sorted_paths = sorted(
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
+            )
+            cfg.resume_from_checkpoint = sorted_paths[-1]
+            LOG.info(
+                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            )
+    resume_from_checkpoint = cfg.resume_from_checkpoint
+
+    trainer = setup_trainer(
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+    )
+
+    model.config.use_cache = False
+
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        LOG.info("Compiling torch model")
+        model = torch.compile(model)
+
+    # go ahead and presave, so we have the adapter config available to inspect
+    if peft_config:
+        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
+        peft_config.save_pretrained(cfg.output_dir)
+
+    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
+    if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+            sys.exit(0)
+
+        signal.signal(
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
+        )
+
+    LOG.info("Starting trainer...")
+    if cfg.group_by_length:
+        LOG.info("hang tight... sorting dataset for group_by_length")
+
+    if not Path(cfg.output_dir).is_dir():
+        os.makedirs(cfg.output_dir, exist_ok=True)
+    tokenizer.save_pretrained(cfg.output_dir)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+
+    if cfg.relora_steps:
+        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
+            model = model.merge_and_unload()
+        else:
+            # final model weights have already been saved by `ReLoRACallback.on_train_end`
+            return model, tokenizer
+
+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
+    if cfg.fsdp:
+        trainer.save_model(cfg.output_dir)
+    elif cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
+
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+    return model, tokenizer
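
Usage sketch (illustrative, not part of the patch). After this refactor, every CLI behavior is selected by the TrainerCliArgs flags that do_cli parses via transformers.HfArgumentParser: --inference, --merge_lora, and --shard dispatch to the new top-level functions, and the default path runs load_datasets followed by axolotl.train.train. A minimal sketch of that flag parsing, assuming an importable axolotl checkout; the flag values are hypothetical:

    # Parse CLI flags the same way do_cli does.
    import transformers
    from axolotl.common.cli import TrainerCliArgs

    parser = transformers.HfArgumentParser((TrainerCliArgs,))
    cli_args, _ = parser.parse_args_into_dataclasses(
        args=["--inference", "--prompter", "AlpacaPrompter"],  # hypothetical values
        return_remaining_strings=True,
    )
    assert cli_args.inference and not cli_args.merge_lora  # do_cli would route to do_inference()

The same split makes the training path callable as a library. A sketch of the programmatic equivalent of `python scripts/finetune.py <config>`, assuming scripts/ is on sys.path so finetune.py is importable, and using a hypothetical config path:

    from pathlib import Path

    from axolotl.common.cli import TrainerCliArgs
    from axolotl.train import train
    from finetune import load_cfg, load_datasets  # assumes scripts/ on sys.path

    cfg = load_cfg(Path("examples/lora.yml"))  # hypothetical config path
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)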