diff --git a/scripts/finetune.py b/scripts/finetune.py index 73f082f21..825574a6e 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,4 +1,5 @@ -"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +"""Prepare and train a model on a dataset. Can also infer from a model or merge lora.""" + import logging from pathlib import Path @@ -30,19 +31,19 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cfg = load_cfg(config, **kwargs) check_accelerate_default_config() check_user_token() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) if parsed_cli_args.inference: do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) elif parsed_cli_args.merge_lora: - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + do_merge_lora(cfg=parsed_cfg) elif parsed_cli_args.shard: - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) + shard(cfg=parsed_cfg) else: dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + train(cfg=parsed_cfg, dataset_meta=dataset_meta) if __name__ == "__main__": diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py index f8331cff5..90c6c044e 100644 --- a/src/axolotl/cli/art.py +++ b/src/axolotl/cli/art.py @@ -21,6 +21,7 @@ AXOLOTL_LOGO = """ def print_dep_versions(): + """Prints versions of various axolotl dependencies.""" packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] max_len = max(len(pkg) for pkg in packages) if is_main_process(): @@ -33,6 +34,7 @@ def print_dep_versions(): def print_legacy_axolotl_text_art(suffix=None): + """Prints axolotl ASCII art and dependency versions.""" font = "nancyj" ascii_text = " axolotl" if suffix: @@ -46,5 +48,6 @@ def print_legacy_axolotl_text_art(suffix=None): def print_axolotl_text_art(): + """Prints axolotl ASCII art.""" if is_main_process(): print(AXOLOTL_LOGO) diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index 1d834691e..c450a1cf6 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -15,6 +15,7 @@ LOG = logging.getLogger(__name__) def check_accelerate_default_config(): + """Logs at warning level if no accelerate config file is found.""" if Path(config_args.default_yaml_config_file).exists(): LOG.warning( f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" @@ -22,6 +23,14 @@ def check_accelerate_default_config(): def check_user_token(): + """Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1. + + Returns: + Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved). + + Raises: + LocalTokenNotFoundError: If HF user info can't be retrieved. + """ # Skip check if HF_HUB_OFFLINE is set to True if os.getenv("HF_HUB_OFFLINE") == "1": LOG.info( diff --git a/src/axolotl/cli/datasets.py b/src/axolotl/cli/datasets.py index e98321c0b..5e44e003b 100644 --- a/src/axolotl/cli/datasets.py +++ b/src/axolotl/cli/datasets.py @@ -2,8 +2,9 @@ import logging import math import random +from typing import Union -from axolotl.common.cli import TrainerCliArgs +from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs from axolotl.train import TrainDatasetMeta from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_dpo_datasets @@ -17,7 +18,7 @@ LOG = logging.getLogger(__name__) def load_datasets( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], ) -> TrainDatasetMeta: tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None @@ -61,7 +62,9 @@ def load_datasets( def load_rl_datasets( *, cfg: DictDefault, - cli_args: TrainerCliArgs, # pylint: disable=unused-argument + cli_args: Union[ + PreprocessCliArgs, TrainerCliArgs + ], # pylint: disable=unused-argument ) -> TrainDatasetMeta: train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) total_num_steps = int( diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 17e2e8dc3..719dc9650 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -29,7 +29,7 @@ def do_evaluate(cfg, cli_args) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + evaluate(cfg=cfg, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index be6ba8236..e950b3528 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -5,7 +5,7 @@ import logging import sys from pathlib import Path from threading import Thread -from typing import Optional, Union +from typing import Union import fire import torch @@ -15,7 +15,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer from axolotl.utils.chat_templates import ( get_chat_template, get_chat_template_from_config, @@ -25,7 +25,13 @@ from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def get_multi_line_input() -> Optional[str]: +def get_multi_line_input() -> str: + """ + Gets multi-line input from terminal. + + Returns: + Possibly multi-line, possibly empty stdin input as a string. + """ print("Give me an instruction (Ctrl + D to submit): ") instruction = "" for line in sys.stdin: @@ -37,9 +43,18 @@ def get_multi_line_input() -> Optional[str]: def do_inference( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + cli_args: InferenceCliArgs, ): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + """ + Runs inference on the command line in a loop. User input is accepted, a chat template + is (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Training-specific CLI arguments. + """ + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) prompter = cli_args.prompter prompter_module = None @@ -121,11 +136,20 @@ def do_inference( def do_inference_gradio( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + cli_args: InferenceCliArgs, ): + """ + Runs inference in a Gradio interface. User input is accepted, a chat template is + (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Training-specific CLI arguments. + """ import gradio as gr - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) prompter = cli_args.prompter prompter_module = None @@ -218,16 +242,25 @@ def do_inference_gradio( ) -def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): +def do_cli( + config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs +) -> None: + """ + Parses axolotl config, training-specific CLI args, and calls `do_inference` or + `do_inference_gradio` as a subroutine. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, inference=True, **kwargs) parsed_cfg.sample_packing = False - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(InferenceCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - parsed_cli_args.inference = True if gradio: do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 30fdb6ad7..73a291fed 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -29,7 +29,14 @@ def cli(): @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) def preprocess(config: str, **kwargs): - """Preprocess datasets before training.""" + """ + Preprocess datasets before training. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ kwargs = {k: v for k, v in kwargs.items() if v is not None} from axolotl.cli.preprocess import do_cli @@ -46,8 +53,16 @@ def preprocess(config: str, **kwargs): ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def train(config: str, accelerate: bool, **kwargs): - """Train or fine-tune a model.""" +def train(config: str, accelerate: bool, **kwargs) -> None: + """ + Train or fine-tune a model. + + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ kwargs = {k: v for k, v in kwargs.items() if v is not None} # Enable expandable segments for cuda allocation to improve VRAM usage @@ -74,8 +89,16 @@ def train(config: str, accelerate: bool, **kwargs): ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) -def evaluate(config: str, accelerate: bool, **kwargs): - """Evaluate a model.""" +def evaluate(config: str, accelerate: bool, **kwargs) -> None: + """ + Evaluate a model. + + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ kwargs = {k: v for k, v in kwargs.items() if v is not None} if accelerate: @@ -97,46 +120,32 @@ def evaluate(config: str, accelerate: bool, **kwargs): default=False, help="Use accelerate launch for multi-GPU inference", ) -@click.option( - "--lora-model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing LoRA model", -) -@click.option( - "--base-model", - type=click.Path(exists=True, path_type=str), - help="Path to base model for non-LoRA models", -) @click.option("--gradio", is_flag=True, help="Launch Gradio interface") -@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def inference( - config: str, - accelerate: bool, - lora_model_dir: Optional[str] = None, - base_model: Optional[str] = None, - **kwargs, -): - """Run inference with a trained model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} - del kwargs["inference"] # interferes with inference.do_cli - - if lora_model_dir: - kwargs["lora_model_dir"] = lora_model_dir - if base_model: - kwargs["base_model"] = base_model +def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: + """ + Run inference with a trained model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + gradio: Whether to use Gradio browser interface or command line for inference. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] if config: base_cmd.append(config) + if gradio: + base_cmd.append("--gradio") cmd = build_command(base_cmd, kwargs) subprocess.run(cmd, check=True) # nosec B603 else: from axolotl.cli.inference import do_cli - do_cli(config=config, **kwargs) + do_cli(config=config, gradio=gradio, **kwargs) @cli.command() @@ -146,20 +155,18 @@ def inference( default=False, help="Use accelerate launch for multi-GPU operations", ) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing model weights to shard", -) -@click.option( - "--save-dir", - type=click.Path(path_type=str), - help="Directory to save sharded weights", -) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def shard(config: str, accelerate: bool, **kwargs): - """Shard model weights.""" +def shard(config: str, accelerate: bool, **kwargs) -> None: + """ + Shard model weights into chunks. + + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ kwargs = {k: v for k, v in kwargs.items() if v is not None} if accelerate: @@ -181,18 +188,18 @@ def shard(config: str, accelerate: bool, **kwargs): default=True, help="Use accelerate launch for weight merging", ) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing sharded weights", -) -@click.option( - "--save-path", type=click.Path(path_type=str), help="Path to save merged weights" -) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): - """Merge sharded FSDP model weights.""" +def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: + """ + Merge sharded FSDP model weights. + + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ kwargs = {k: v for k, v in kwargs.items() if v is not None} if accelerate: @@ -214,27 +221,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): @cli.command() @click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--lora-model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing the LoRA model to merge", -) -@click.option( - "--output-dir", - type=click.Path(path_type=str), - help="Directory to save the merged model", -) -def merge_lora( - config: str, - lora_model_dir: Optional[str] = None, - output_dir: Optional[str] = None, -): - """Merge a trained LoRA into a base model""" - kwargs = {} - if lora_model_dir: - kwargs["lora_model_dir"] = lora_model_dir - if output_dir: - kwargs["output_dir"] = output_dir +@add_options_from_dataclass(TrainerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def merge_lora(config: str, **kwargs) -> None: + """ + Merge trained LoRA adapters into a base model. + + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ + kwargs = {k: v for k, v in kwargs.items() if v is not None} from axolotl.cli.merge_lora import do_cli @@ -244,13 +243,17 @@ def merge_lora( @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") -def fetch(directory: str, dest: Optional[str]): +def fetch(directory: str, dest: Optional[str]) -> None: """ Fetch example configs or other resources. Available directories: - examples: Example configuration files - deepspeed_configs: DeepSpeed configuration files + + Args: + directory: One of `examples`, `deepspeed_configs`. + dest: Optional destination directory. """ fetch_from_github(f"{directory}/", dest) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 42e315bfe..6e461bf53 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -13,18 +13,23 @@ from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.merge_lora") +LOG = logging.getLogger(__name__) -def do_merge_lora( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) +def do_merge_lora(*, cfg: DictDefault) -> None: + """ + Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config + along with the LoRA adapters to combine them into a single base model. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + """ + print_axolotl_text_art() + + model, tokenizer = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True - LOG.info("running merge of LoRA with base model") + LOG.info("Running merge of LoRA with base model...") model = model.merge_and_unload(progressbar=True) try: model.to(dtype=cfg.torch_dtype) @@ -33,7 +38,7 @@ def do_merge_lora( model.generation_config.do_sample = True if cfg.local_rank == 0: - LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") + LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...") model.save_pretrained( str(Path(cfg.output_dir) / "merged"), safe_serialization=safe_serialization, @@ -42,10 +47,21 @@ def do_merge_lora( tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine. + Note that various config values will be overwritten to allow the LoRA merge logic to work + as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.). + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + + Raises: + ValueError: If target directory for LoRA merged model does not exist. + """ # pylint: disable=duplicate-code - print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) @@ -76,7 +92,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): parsed_cfg.fsdp = None parsed_cfg.fsdp_config = None - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) + do_merge_lora(cfg=parsed_cfg) if __name__ == "__main__": diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 182abfaa4..a12be5d27 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -18,24 +18,25 @@ from axolotl.cli.config import load_cfg from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code +def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: + """ + Preprocesses dataset specified in axolotl config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Preprocessing-specific CLI arguments. + """ print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parsed_cfg.is_preprocess = True check_accelerate_default_config() check_user_token() - parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if not parsed_cfg.dataset_prepared_path: + if not cfg.dataset_prepared_path: msg = ( Fore.RED + "preprocess CLI called without dataset_prepared_path set, " @@ -43,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + Fore.RESET ) LOG.warning(msg) - parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH + cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH with disable_datasets_caching(): - if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": - load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if cfg.rl: # and cfg.rl != "orpo": + load_rl_datasets(cfg=cfg, cli_args=cli_args) else: - load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + load_datasets(cfg=cfg, cli_args=cli_args) - if parsed_cli_args.download: - model_name = parsed_cfg.base_model + if cli_args.download: + model_name = cfg.base_model with warnings.catch_warnings(): # there are a bunch of useless UserWarnings about # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model" @@ -69,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.info( Fore.GREEN - + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`" + Fore.RESET ) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parsed_cfg.is_preprocess = True + parser = transformers.HfArgumentParser(PreprocessCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_preprocess(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index 86533db7e..ad680e4b3 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -5,39 +5,27 @@ from pathlib import Path from typing import Union import fire -import transformers from dotenv import load_dotenv from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.common.cli import load_model_and_tokenizer from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def shard( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) +def shard(*, cfg: DictDefault): + model, _ = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True LOG.debug("Re-saving model w/ sharding") model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, **kwargs) - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.shard = True - - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) + shard(cfg=parsed_cfg) if __name__ == "__main__": diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index eb70cfaa1..22fa143f0 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -15,21 +15,21 @@ from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.integrations.base import PluginManager from axolotl.train import train +from axolotl.utils.dict import DictDefault LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - parsed_cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - return do_train(parsed_cfg, parsed_cli_args) +def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Trains transformer model by first loading the dataset(s) specified in the `axolotl` + config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin + manager's `post_train_unload` once training completes. - -def do_train(cfg, cli_args) -> None: + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Training-specific CLI arguments. + """ print_axolotl_text_art() check_accelerate_default_config() check_user_token() @@ -39,7 +39,7 @@ def do_train(cfg, cli_args) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) plugin_manager = PluginManager.get_instance() del model @@ -48,6 +48,24 @@ def do_train(cfg, cli_args) -> None: plugin_manager.post_train_unload(cfg) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_train(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 02ad9201b..f714d7d4b 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -1,10 +1,10 @@ -""" -shared module for cli specific things -""" +"""Shared module for CLI specific utilities.""" import logging from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional, Tuple + +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.logging_config import configure_logging @@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer configure_logging() -LOG = logging.getLogger("axolotl.common.cli") +LOG = logging.getLogger(__name__) @dataclass class PreprocessCliArgs: - """ - dataclass representing arguments for preprocessing only - """ + """Dataclass with CLI arguments for `axolotl preprocess` command.""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -30,14 +28,11 @@ class PreprocessCliArgs: @dataclass class TrainerCliArgs: - """ - dataclass representing the various non-training arguments - """ + """Dataclass with CLI arguments for `axolotl train` command.""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) debug_num_examples: int = field(default=0) - inference: bool = field(default=False) merge_lora: bool = field(default=False) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) @@ -45,25 +40,40 @@ class TrainerCliArgs: @dataclass class EvaluateCliArgs: - """ - dataclass representing the various evaluation arguments - """ + """Dataclass with CLI arguments for `axolotl evaluate` command.""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) debug_num_examples: int = field(default=0) +@dataclass +class InferenceCliArgs: + """Dataclass with CLI arguments for `axolotl inference` command.""" + + prompter: Optional[str] = field(default=None) + + def load_model_and_tokenizer( *, cfg: DictDefault, - cli_args: TrainerCliArgs, -): + inference: bool = False, +) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: + """ + Helper function for loading a model and tokenizer specified in the given `axolotl` + config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + `transformers` model and tokenizer. + """ LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) - LOG.info("loading model and (optionally) peft_config...") - inference = getattr(cli_args, "inference", False) + LOG.info("loading model...") model, _ = load_model(cfg, tokenizer, inference=inference) return model, tokenizer diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index acf15e3fc..fec7e9ecb 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -9,7 +9,6 @@ from typing import Dict, Optional import torch from accelerate.logging import get_logger -from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils import set_pytorch_cuda_alloc_conf @@ -62,16 +61,13 @@ def evaluate_dataset( return metrics -def evaluate( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -) -> Dict[str, float]: +def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: """ Evaluate a model on training and validation datasets Args: - cfg: Configuration dictionary - cli_args: Command line arguments - dataset_meta: Dataset metadata containing training and evaluation datasets + cfg: Config dictionary. + dataset_meta: Dataset metadata containing training and evaluation datasets. Returns: Tuple containing: @@ -102,9 +98,7 @@ def evaluate( # Load model LOG.debug("loading model for evaluation...") - model, _ = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, _ = load_model(cfg, tokenizer, processor=processor) # Set up trainer trainer = setup_trainer( diff --git a/src/axolotl/train.py b/src/axolotl/train.py index a74ecc2ec..20795f4e7 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -19,7 +19,6 @@ from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from axolotl.common.cli import TrainerCliArgs from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) @@ -39,14 +38,12 @@ src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) configure_logging() -LOG = get_logger("axolotl.train") +LOG = get_logger(__name__) @dataclass class TrainDatasetMeta: - """ - dataclass to capture the dataset specific options for training - """ + """Dataclass with fields for training and validation datasets and metadata.""" train_dataset: Dataset eval_dataset: Optional[Dataset] = None @@ -54,7 +51,7 @@ class TrainDatasetMeta: def train( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta + *, cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: # Load tokenizer LOG.debug( @@ -93,9 +90,7 @@ def train( if cfg.adapter: msg += " and peft_config..." LOG.debug(msg) - model, peft_config = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, peft_config = load_model(cfg, tokenizer, processor=processor) if model.generation_config is not None: model.generation_config.do_sample = True @@ -107,9 +102,7 @@ def train( model_ref = None # explicit setting to None else: # load the model again for model_ref/baseline - model_ref, _ = load_model( - cfg, tokenizer, inference=cli_args.inference, reference_model=True - ) + model_ref, _ = load_model(cfg, tokenizer, reference_model=True) safe_serialization = cfg.save_safetensors is True