diff --git a/scripts/finetune.py b/scripts/finetune.py deleted file mode 100644 index d5bbcaf8f..000000000 --- a/scripts/finetune.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" -import logging -from pathlib import Path - -import fire -import transformers - -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - do_inference, - do_merge_lora, - load_cfg, - load_datasets, - print_axolotl_text_art, -) -from axolotl.cli.shard import shard -from axolotl.common.cli import TrainerCliArgs -from axolotl.train import train - -LOG = logging.getLogger("axolotl.scripts.finetune") - - -def do_cli(config: Path = Path("examples/"), **kwargs): - print_axolotl_text_art() - LOG.warning( - str( - PendingDeprecationWarning( - "scripts/finetune.py will be replaced with calling axolotl.cli.train" - ) - ) - ) - parsed_cfg = load_cfg(config, **kwargs) - check_accelerate_default_config() - check_user_token() - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if parsed_cli_args.inference: - do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) - elif parsed_cli_args.merge_lora: - do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) - elif parsed_cli_args.shard: - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) - else: - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) - train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) - - -if __name__ == "__main__": - fire.Fire(do_cli) diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index d07b10ce3..b20e4f085 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -1,568 +1,5 @@ -"""Prepare and train a model on a dataset. 
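The deleted `scripts/finetune.py` above was a single `fire`-driven entrypoint that dispatched between training, inference, LoRA merging, and sharding. As a hedged sketch (not part of this diff), the same dispatch maps onto the per-command `do_cli` entrypoints this PR introduces; the shim itself is illustrative, and the `--shard` path has no successor since `src/axolotl/cli/shard.py` is deleted later in this diff:

```python
# Illustrative migration shim, not part of this PR: route legacy
# scripts/finetune.py-style flags to the new per-command modules.
from pathlib import Path


def finetune_compat(config: Path, inference: bool = False, merge_lora: bool = False, **kwargs):
    """Dispatch legacy finetune.py-style invocations to the new CLI modules."""
    if inference:
        from axolotl.cli.inference import do_cli as run
    elif merge_lora:
        from axolotl.cli.merge_lora import do_cli as run
    else:
        from axolotl.cli.train import do_cli as run
    run(config=config, **kwargs)
```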
Can also infer from a model or merge lora""" +"""Axolotl CLI module initialization.""" -import importlib -import json -import logging -import math import os -import random -import sys -import tempfile -from pathlib import Path -from threading import Thread -from typing import Any, Dict, List, Optional, Union -from urllib.parse import urlparse - -import requests -import torch -import yaml - -# add src to the pythonpath so we don't need to pip install this -from accelerate.commands.config import config_args -from art import text2art -from huggingface_hub import HfApi -from huggingface_hub.utils import LocalTokenNotFoundError -from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer -from transformers.utils import is_torch_bf16_gpu_available -from transformers.utils.import_utils import _is_package_available - -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer -from axolotl.logging_config import configure_logging -from axolotl.train import TrainDatasetMeta -from axolotl.utils.chat_templates import ( - get_chat_template, - get_chat_template_from_config, -) -from axolotl.utils.comet_ import setup_comet_env_vars -from axolotl.utils.config import ( - normalize_cfg_datasets, - normalize_config, - prepare_plugins, - validate_config, -) -from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import is_main_process -from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.models import load_processor, load_tokenizer -from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env -from axolotl.utils.wandb_ import setup_wandb_env_vars - -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - -configure_logging() -LOG = logging.getLogger("axolotl.scripts") os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - -AXOLOTL_LOGO = """ - #@@ #@@ @@# @@# - @@ @@ @@ @@ =@@# @@ #@ =@@#. - @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ - #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ - @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ - @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ - @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ - =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ - @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. 
@@
-                            =@#     @#  #@=     #@   =#@@@@#=   +#@@=  +#@@@@#=  .##@@+   @@
-  @@@@          @@@@@@@@@@@@@@@@
-"""
-
-
-def print_legacy_axolotl_text_art(suffix=None):
-    font = "nancyj"
-    ascii_text = " axolotl"
-    if suffix:
-        ascii_text += f" x {suffix}"
-    ascii_art = text2art(ascii_text, font=font)
-
-    if is_main_process():
-        print(ascii_art)
-
-    print_dep_versions()
-
-
-def print_axolotl_text_art(
-    **kwargs,  # pylint: disable=unused-argument
-):
-    if is_main_process():
-        print(AXOLOTL_LOGO)
-
-
-def print_dep_versions():
-    packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
-    max_len = max(len(pkg) for pkg in packages)
-    if is_main_process():
-        print("*" * 40)
-        print("**** Axolotl Dependency Versions *****")
-        for pkg in packages:
-            pkg_version = _is_package_available(pkg, return_version=True)
-            print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
-        print("*" * 40)
-
-
-def check_remote_config(config: Union[str, Path]):
-    # Check if the config is a valid HTTPS URL to a .yml or .yaml file
-    if not (isinstance(config, str) and config.startswith("https://")):
-        return config  # Return the original value if it's not a valid URL
-
-    filename = os.path.basename(urlparse(config).path)
-    temp_dir = tempfile.mkdtemp()
-
-    try:
-        response = requests.get(config, timeout=30)
-        response.raise_for_status()  # Check for HTTP errors
-
-        content = response.content
-        try:
-            # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
-            json.loads(content)
-            # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
-            LOG.warning(
-                f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
-            )
-        except json.JSONDecodeError:
-            # If it's not valid JSON, verify it's valid YAML
-            try:
-                yaml.safe_load(content)
-            except yaml.YAMLError as err:
-                raise ValueError(
-                    f"Failed to parse the content at {config} as YAML: {err}"
-                ) from err
-
-        # Write the content to a file if it's valid YAML (or JSON treated as YAML)
-        output_path = Path(temp_dir) / filename
-        with open(output_path, "wb") as file:
-            file.write(content)
-        LOG.info(
-            f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
-        )
-        return output_path
-
-    except requests.RequestException as err:
-        # This catches all requests-related exceptions including HTTPError
-        raise RuntimeError(f"Failed to download {config}: {err}") from err
-    except Exception as err:
-        # Catch-all for any other exceptions
-        raise err
-
-
-def get_multi_line_input() -> Optional[str]:
-    print("Give me an instruction (Ctrl + D to submit): ")
-    instruction = ""
-    for line in sys.stdin:
-        instruction += line  # pylint: disable=consider-using-join
-    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
-    return instruction
-
-
-def do_merge_lora(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    safe_serialization = cfg.save_safetensors is True
-
-    LOG.info("running merge of LoRA with base model")
-    model = model.merge_and_unload(progressbar=True)
-    try:
-        model.to(dtype=cfg.torch_dtype)
-    except RuntimeError:
-        pass
-    model.generation_config.do_sample = True
-
-    if cfg.local_rank == 0:
-        LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
-        model.save_pretrained(
-            str(Path(cfg.output_dir) / "merged"),
-            safe_serialization=safe_serialization,
-            progressbar=True,
-        )
-
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) - - -def do_inference( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - - prompter_module = None - chat_template_str = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template) - elif cfg.datasets[0].type == "chat_template": - chat_template_str = get_chat_template_from_config( - cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer - ) - - model = model.to(cfg.device, dtype=cfg.torch_dtype) - - while True: - print("=" * 80) - # support for multiline inputs - instruction = get_multi_line_input() - if not instruction: - return - - if prompter_module: - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - - if chat_template_str: - batch = tokenizer.apply_chat_template( - [ - { - "role": "user", - "content": prompt, - } - ], - return_tensors="pt", - add_special_tokens=True, - add_generation_prompt=True, - chat_template=chat_template_str, - tokenize=True, - return_dict=True, - ) - else: - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - print("=" * 40) - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextStreamer(tokenizer) - generated = model.generate( - inputs=batch["input_ids"].to(cfg.device), - generation_config=generation_config, - streamer=streamer, - ) - print("=" * 40) - print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) - - -def do_inference_gradio( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - import gradio as gr - - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - - prompter_module = None - chat_template_str = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) - - model = model.to(cfg.device, dtype=cfg.torch_dtype) - - def generate(instruction): - if not instruction: - return - if prompter_module: - # pylint: disable=stop-iteration-return - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - - if chat_template_str: - batch = tokenizer.apply_chat_template( - [ - { - "role": "user", - "content": prompt, - } - ], - return_tensors="pt", - add_special_tokens=True, - add_generation_prompt=True, - chat_template=chat_template_str, - tokenize=True, - return_dict=True, - ) - else: - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), - temperature=cfg.get("gradio_temperature", 0.9), - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - 
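The chat-template branch of `do_inference` above returns a dict-like `BatchEncoding` when `return_dict=True`, so generation can index `input_ids` the same way as the plain-tokenizer branch. A standalone sketch of that call; the model id is illustrative, and any tokenizer that ships a chat template should behave similarly:

```python
from transformers import AutoTokenizer

# Illustrative model id; assumes its tokenizer bundles a chat template.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

batch = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,  # append the assistant header so the model replies as the assistant
    return_tensors="pt",
    tokenize=True,
    return_dict=True,  # yields {"input_ids": ..., "attention_mask": ...} instead of a bare tensor
)
print(batch["input_ids"].shape)
```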
eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextIteratorStreamer(tokenizer) - generation_kwargs = { - "inputs": batch["input_ids"].to(cfg.device), - "attention_mask": batch["attention_mask"].to(cfg.device), - "generation_config": generation_config, - "streamer": streamer, - } - - thread = Thread(target=model.generate, kwargs=generation_kwargs) - thread.start() - - all_text = "" - - for new_text in streamer: - all_text += new_text - yield all_text - - demo = gr.Interface( - fn=generate, - inputs="textbox", - outputs="text", - title=cfg.get("gradio_title", "Axolotl Gradio Interface"), - ) - - demo.queue().launch( - show_api=False, - share=cfg.get("gradio_share", True), - server_name=cfg.get("gradio_server_name", "127.0.0.1"), - server_port=cfg.get("gradio_server_port", None), - ) - - -def choose_config(path: Path): - yaml_files = list(path.glob("*.yml")) - - if not yaml_files: - raise ValueError( - "No YAML config files found in the specified directory. Are you using a .yml extension?" - ) - - if len(yaml_files) == 1: - print(f"Using default YAML file '{yaml_files[0]}'") - return str(yaml_files[0]) - - print("Choose a YAML file:") - for idx, file in enumerate(yaml_files): - print(f"{idx + 1}. {file}") - - chosen_file = None - while chosen_file is None: - try: - choice = int(input("Enter the number of your choice: ")) - if 1 <= choice <= len(yaml_files): - chosen_file = str(yaml_files[choice - 1]) - else: - print("Invalid choice. Please choose a number from the list.") - except ValueError: - print("Invalid input. Please enter a number.") - - return chosen_file - - -def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool: - return not any(el in list2 for el in list1) - - -def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): - config = check_remote_config(config) - if Path(config).is_dir(): - config = choose_config(Path(config)) - - # load the config from the yaml file - with open(config, encoding="utf-8") as file: - cfg: DictDefault = DictDefault(yaml.safe_load(file)) - # if there are any options passed in the cli, if it is something that seems valid from the yaml, - # then overwrite the value - cfg_keys = cfg.keys() - for k, _ in kwargs.items(): - # if not strict, allow writing to cfg even if it's not in the yml already - if k in cfg_keys or not cfg.strict: - # handle booleans - if isinstance(cfg[k], bool): - cfg[k] = bool(kwargs[k]) - else: - cfg[k] = kwargs[k] - - cfg.axolotl_config_path = config - - try: - device_props = torch.cuda.get_device_properties("cuda") - gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) - except: # pylint: disable=bare-except # noqa: E722 - gpu_version = None - - prepare_plugins(cfg) - - cfg = validate_config( - cfg, - capabilities={ - "bf16": is_torch_bf16_gpu_available(), - "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), - "compute_capability": gpu_version, - }, - env_capabilities={ - "torch_version": str(torch.__version__).split("+", maxsplit=1)[0], - }, - ) - - prepare_optim_env(cfg) - - prepare_opinionated_env(cfg) - - normalize_config(cfg) - - normalize_cfg_datasets(cfg) - - setup_wandb_env_vars(cfg) - - setup_mlflow_env_vars(cfg) - - setup_comet_env_vars(cfg) - - return cfg - - -def load_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -) -> TrainDatasetMeta: - tokenizer = 
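The Gradio path above pairs `TextIteratorStreamer` with a worker thread: `generate` blocks until completion, so it runs off the main thread while the streamer is consumed as an iterator. A minimal self-contained sketch of the same pattern, with a tiny model id chosen only to keep the example cheap:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # illustrative tiny model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks, so run it on a worker thread and stream from the main thread
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer},
)
thread.start()
for chunk in streamer:  # yields decoded text incrementally
    print(chunk, end="", flush=True)
thread.join()
```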
load_tokenizer(cfg) - processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, - tokenizer, - processor=processor, - ) - - if ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ): - LOG.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - ) - - LOG.info("printing prompters...") - for prompter in prompters: - LOG.info(prompter) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def load_rl_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, # pylint: disable=unused-argument -) -> TrainDatasetMeta: - train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - - if cli_args.debug or cfg.debug: - LOG.info("check_dataset_labels...") - - tokenizer = load_tokenizer(cfg) - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - rl_mode=True, - ) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def check_accelerate_default_config(): - if Path(config_args.default_yaml_config_file).exists(): - LOG.warning( - f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" - ) - - -def check_user_token(): - # Skip check if HF_HUB_OFFLINE is set to True - if os.getenv("HF_HUB_OFFLINE") == "1": - LOG.info( - "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." - ) - return True - - # Verify if token is valid - api = HfApi() - try: - user_info = api.whoami() - return bool(user_info) - except LocalTokenNotFoundError: - LOG.warning( - "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." 
-        )
-        return False
diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py new file mode 100644 index 000000000..0618e07f1 --- /dev/null +++ b/src/axolotl/cli/args.py
@@ -0,0 +1,43 @@
+"""Module for axolotl CLI command arguments."""
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+
+@dataclass
+class PreprocessCliArgs:
+    """Dataclass with CLI arguments for `axolotl preprocess` command."""
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=1)
+    prompter: Optional[str] = field(default=None)
+    download: Optional[bool] = field(default=True)
+
+
+@dataclass
+class TrainerCliArgs:
+    """Dataclass with CLI arguments for `axolotl train` command."""
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=0)
+    merge_lora: bool = field(default=False)
+    prompter: Optional[str] = field(default=None)
+    shard: bool = field(default=False)
+
+
+@dataclass
+class EvaluateCliArgs:
+    """Dataclass with CLI arguments for `axolotl evaluate` command."""
+
+    debug: bool = field(default=False)
+    debug_text_only: bool = field(default=False)
+    debug_num_examples: int = field(default=0)
+
+
+@dataclass
+class InferenceCliArgs:
+    """Dataclass with CLI arguments for `axolotl inference` command."""
+
+    prompter: Optional[str] = field(default=None)
diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py new file mode 100644 index 000000000..6ed22a52d --- /dev/null +++ b/src/axolotl/cli/art.py
@@ -0,0 +1,23 @@
+"""Axolotl ASCII logo utils."""
+
+from axolotl.utils.distributed import is_main_process
+
+AXOLOTL_LOGO = """
+ #@@ #@@ @@# @@#
+ @@ @@ @@ @@ =@@# @@ #@ =@@#.
+ @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
+ #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
+ @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
+ @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
+ @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
+ =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
+ @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
+ =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
+ @@@@ @@@@@@@@@@@@@@@@
+"""
+
+
+def print_axolotl_text_art():
+    """Prints axolotl ASCII art."""
+    if is_main_process():
+        print(AXOLOTL_LOGO)
diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py new file mode 100644 index 000000000..cc3ed0d9f --- /dev/null +++ b/src/axolotl/cli/checks.py
@@ -0,0 +1,50 @@
+"""Various checks for Axolotl CLI."""
+
+import logging
+import os
+from pathlib import Path
+
+from accelerate.commands.config import config_args
+from huggingface_hub import HfApi
+from huggingface_hub.utils import LocalTokenNotFoundError
+
+from axolotl.logging_config import configure_logging
+
+configure_logging()
+LOG = logging.getLogger(__name__)
+
+
+def check_accelerate_default_config() -> None:
+    """Logs at warning level if an accelerate config file is found."""
+    if Path(config_args.default_yaml_config_file).exists():
+        LOG.warning(
+            f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
+        )
+
+
+def check_user_token() -> bool:
+    """Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
+
+    Returns:
+        Boolean indicating a successful check: True if HF_HUB_OFFLINE=1 or HF user
+        info is retrieved; False if no valid token is found (the underlying
+        `LocalTokenNotFoundError` is caught and logged rather than raised).
+    """
+    # Skip check if HF_HUB_OFFLINE is set to True
+    if os.getenv("HF_HUB_OFFLINE") == "1":
+        LOG.info(
+            "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
+        )
+        return True
+
+    # Verify if token is valid
+    api = HfApi()
+    try:
+        user_info = api.whoami()
+        return bool(user_info)
+    except LocalTokenNotFoundError:
+        LOG.warning(
+            "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
+        )
+        return False
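The per-command dataclasses in `args.py` above are consumed in two ways: `HfArgumentParser` parses them in the module-level `do_cli` entrypoints, and `add_options_from_dataclass` (in `main.py`, later in this diff) turns them into click options. A sketch of the parser side, with illustrative argv:

```python
from transformers import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs

parser = HfArgumentParser(TrainerCliArgs)
# Unknown flags (e.g. config overrides) are returned rather than raising,
# which is how the entrypoints let extra keys flow through to load_cfg.
cli_args, remaining = parser.parse_args_into_dataclasses(
    args=["--debug", "True", "--learning_rate", "1e-5"],  # illustrative argv
    return_remaining_strings=True,
)
print(cli_args.debug)  # True
print(remaining)       # ['--learning_rate', '1e-5']
```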
+ """ + # Skip check if HF_HUB_OFFLINE is set to True + if os.getenv("HF_HUB_OFFLINE") == "1": + LOG.info( + "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." + ) + return True + + # Verify if token is valid + api = HfApi() + try: + user_info = api.whoami() + return bool(user_info) + except LocalTokenNotFoundError: + LOG.warning( + "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." + ) + return False diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py new file mode 100644 index 000000000..166a67670 --- /dev/null +++ b/src/axolotl/cli/config.py @@ -0,0 +1,217 @@ +"""Configuration loading and processing.""" + +import json +import logging +import os +import tempfile +from pathlib import Path +from typing import Union +from urllib.parse import urlparse + +import requests +import torch +import yaml +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.integrations.base import PluginManager +from axolotl.utils.comet_ import setup_comet_env_vars +from axolotl.utils.config import ( + normalize_cfg_datasets, + normalize_config, + validate_config, +) +from axolotl.utils.dict import DictDefault +from axolotl.utils.mlflow_ import setup_mlflow_env_vars +from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env +from axolotl.utils.wandb_ import setup_wandb_env_vars + +LOG = logging.getLogger(__name__) + + +def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: + """ + First, determines if the passed config is a valid HTTPS URL. Then, attempts to query + for it and parse its content, first as JSON, then as YAML (YAML is preferred). + Finally, the parsed content is written to a local file and its path is returned. + + Args: + config: HTTPS URL to a YAML or JSON file. + + Returns: + Either the original `config` if it's not a valid HTTPS URL, or the path to the + downloaded remote config. + + Raises: + ValueError: If the remote configuration is neither valid JSON or YAML. + RuntimeError: If some request-related exception occurs from the file download. + Exception: Catch-all for any other exception. + """ + # Check if the config is a valid HTTPS URL to a .yml or .yaml file + if not (isinstance(config, str) and config.startswith("https://")): + return config # Return the original value if it's not a valid URL + + filename = os.path.basename(urlparse(config).path) + temp_dir = tempfile.mkdtemp() + + try: + response = requests.get(config, timeout=30) + response.raise_for_status() # Check for HTTP errors + + content = response.content + try: + # Try parsing as JSON first to catch cases where JSON content is mistakenly + # considered YAML. + json.loads(content) + + # Log a warning but do not raise an error; JSON is technically valid YAML. + # This can happen when you forget to point to a raw GitHub link. + LOG.warning( + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." 
+
+
+def choose_config(path: Path) -> str:
+    """
+    Helper method for choosing an `axolotl` config YAML file (considering only files
+    ending with `.yml` or `.yaml`). If more than one config file exists in the passed
+    `path`, the user is prompted to choose one.
+
+    Args:
+        path: Directory in which config file(s) are stored.
+
+    Returns:
+        Path to either (1) the sole YAML file, or (2) if more than one YAML file
+        exists, the user-selected YAML file.
+
+    Raises:
+        ValueError: If no YAML files are found in the given `path`.
+    """
+    yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
+
+    if not yaml_files:
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
+
+    if len(yaml_files) == 1:
+        print(f"Using default YAML file '{yaml_files[0]}'")
+        return str(yaml_files[0])
+
+    print("Choose a YAML file:")
+    for idx, file in enumerate(yaml_files):
+        print(f"{idx + 1}. {file}")
+
+    chosen_file = None
+    while chosen_file is None:
+        try:
+            choice = int(input("Enter the number of your choice: "))
+            if 1 <= choice <= len(yaml_files):
+                chosen_file = str(yaml_files[choice - 1])
+            else:
+                print("Invalid choice. Please choose a number from the list.")
+        except ValueError:
+            print("Invalid input. Please enter a number.")
+
+    return chosen_file
+
+
+def prepare_plugins(cfg: DictDefault):
+    """
+    Registers the plugins for the given configuration.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+    """
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        for plugin_name in cfg["plugins"]:
+            plugin_manager.register(plugin_name)
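`load_cfg` (next) merges CLI keyword overrides into the YAML-derived config. The rule is small but easy to miss: unknown keys are accepted only when the config is not `strict`, and a boolean in the YAML forces the override to be coerced to `bool`. A distilled sketch of just that rule, using a plain dict in place of `DictDefault`:

```python
def apply_cli_overrides(cfg: dict, strict: bool, **kwargs) -> dict:
    """Distilled sketch of load_cfg's override rule (plain dict stands in for DictDefault)."""
    for key, value in kwargs.items():
        if key in cfg or not strict:
            # A boolean already in the YAML forces bool coercion of the override
            cfg[key] = bool(value) if isinstance(cfg.get(key), bool) else value
    return cfg


cfg = apply_cli_overrides({"micro_batch_size": 2, "bf16": False}, strict=True, bf16=1)
print(cfg)  # {'micro_batch_size': 2, 'bf16': True}
```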
+ """ + config = check_remote_config(config) + if Path(config).is_dir(): + config = choose_config(Path(config)) + + # Load the config from the yaml file + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + + # If there are any options passed in the cli, if it is something that seems valid + # from the yaml, then overwrite the value + cfg_keys = cfg.keys() + for k, _ in kwargs.items(): + # if not strict, allow writing to cfg even if it's not in the yml already + if k in cfg_keys or not cfg.strict: + # handle booleans + if isinstance(cfg[k], bool): + cfg[k] = bool(kwargs[k]) + else: + cfg[k] = kwargs[k] + + cfg.axolotl_config_path = config + + try: + device_props = torch.cuda.get_device_properties("cuda") + gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) + except: # pylint: disable=bare-except # noqa: E722 + gpu_version = None + + prepare_plugins(cfg) + + cfg = validate_config( + cfg, + capabilities={ + "bf16": is_torch_bf16_gpu_available(), + "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), + "compute_capability": gpu_version, + }, + env_capabilities={ + "torch_version": str(torch.__version__).split("+", maxsplit=1)[0] + }, + ) + + prepare_optim_env(cfg) + prepare_opinionated_env(cfg) + normalize_config(cfg) + normalize_cfg_datasets(cfg) + setup_wandb_env_vars(cfg) + setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + + return cfg diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8e99d6f4b..c89715719 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run evaluation on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,35 +8,48 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.evaluate import evaluate +from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.evaluate") +LOG = logging.getLogger(__name__) -def do_evaluate(cfg, cli_args) -> None: +def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Evaluates a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes + evaluation metrics on the given dataset(s) and writes them to disk. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: CLI arguments. 
+ """ # pylint: disable=duplicate-code print_axolotl_text_art() check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + evaluate(cfg=cfg, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_evaluate`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index a5f1a8ad8..e11a39bd6 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,32 +1,267 @@ -""" -CLI to run inference on a trained model -""" +"""CLI to run inference on a trained model.""" + +import importlib +import logging +import sys from pathlib import Path +from threading import Thread from typing import Union import fire +import torch import transformers from dotenv import load_dotenv +from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer -from axolotl.cli import ( - do_inference, - do_inference_gradio, - load_cfg, - print_axolotl_text_art, +from axolotl.cli.args import InferenceCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.utils.chat_templates import ( + get_chat_template, + get_chat_template_from_config, ) -from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): +def get_multi_line_input() -> str: + """ + Gets multi-line input from terminal. + + Returns: + Possibly multi-line, possibly empty stdin input as a string. + """ + print("Give me an instruction (Ctrl + D to submit): ") + + instruction = "" + for line in sys.stdin: + instruction += line # pylint: disable=consider-using-join + + return instruction + + +def do_inference( + *, + cfg: DictDefault, + cli_args: InferenceCliArgs, +): + """ + Runs inference on the command line in a loop. User input is accepted, a chat template + is (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Inference-specific CLI arguments. 
+ """ + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template) + elif cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer + ) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + while True: + print("=" * 80) + # support for multiline inputs + instruction = get_multi_line_input() + if not instruction: + return + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + print("=" * 40) + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=1024, + temperature=0.9, + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextStreamer(tokenizer) + generated = model.generate( + inputs=batch["input_ids"].to(cfg.device), + generation_config=generation_config, + streamer=streamer, + ) + print("=" * 40) + print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) + + +def do_inference_gradio( + *, + cfg: DictDefault, + cli_args: InferenceCliArgs, +): + """ + Runs inference in a Gradio interface. User input is accepted, a chat template is + (optionally) applied, and the model specified in the `axolotl` config is used to + generate completions according to a default generation config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Inference-specific CLI arguments. 
+ """ + import gradio as gr + + model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + def generate(instruction): + if not instruction: + return + if prompter_module: + # pylint: disable=stop-iteration-return + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), + temperature=cfg.get("gradio_temperature", 0.9), + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextIteratorStreamer(tokenizer) + generation_kwargs = { + "inputs": batch["input_ids"].to(cfg.device), + "attention_mask": batch["attention_mask"].to(cfg.device), + "generation_config": generation_config, + "streamer": streamer, + } + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + all_text = "" + + for new_text in streamer: + all_text += new_text + yield all_text + + demo = gr.Interface( + fn=generate, + inputs="textbox", + outputs="text", + title=cfg.get("gradio_title", "Axolotl Gradio Interface"), + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) + + +def do_cli( + config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs +) -> None: + """ + Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. 
+ """ # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, inference=True, **kwargs) parsed_cfg.sample_packing = False - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(InferenceCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - parsed_cli_args.inference = True if gradio: do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43b..43e2de3db 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,18 +1,20 @@ -"""CLI definition for various axolotl commands.""" +"""Click CLI definitions for various axolotl commands.""" # pylint: disable=redefined-outer-name + import subprocess # nosec B404 from typing import Optional import click import axolotl +from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, build_command, fetch_from_github, + filter_none_kwargs, ) -from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -27,10 +29,16 @@ def cli(): @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) -def preprocess(config: str, **kwargs): - """Preprocess datasets before training.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def preprocess(config: str, **kwargs) -> None: + """ + Preprocess datasets before training. + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ from axolotl.cli.preprocess import do_cli do_cli(config=config, **kwargs) @@ -45,10 +53,17 @@ def preprocess(config: str, **kwargs): ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def train(config: str, accelerate: bool, **kwargs): - """Train or fine-tune a model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def train(config: str, accelerate: bool, **kwargs) -> None: + """ + Train or fine-tune a model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() @@ -73,10 +88,17 @@ def train(config: str, accelerate: bool, **kwargs): ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) -def evaluate(config: str, accelerate: bool, **kwargs): - """Evaluate a model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} +@filter_none_kwargs +def evaluate(config: str, accelerate: bool, **kwargs) -> None: + """ + Evaluate a model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. 
+ """ if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] if config: @@ -96,81 +118,33 @@ def evaluate(config: str, accelerate: bool, **kwargs): default=False, help="Use accelerate launch for multi-GPU inference", ) -@click.option( - "--lora-model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing LoRA model", -) -@click.option( - "--base-model", - type=click.Path(exists=True, path_type=str), - help="Path to base model for non-LoRA models", -) @click.option("--gradio", is_flag=True, help="Launch Gradio interface") -@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def inference( - config: str, - accelerate: bool, - lora_model_dir: Optional[str] = None, - base_model: Optional[str] = None, - **kwargs, -): - """Run inference with a trained model.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} - del kwargs["inference"] # interferes with inference.do_cli - - if lora_model_dir: - kwargs["lora_model_dir"] = lora_model_dir - if base_model: - kwargs["base_model"] = base_model +@filter_none_kwargs +def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: + """ + Run inference with a trained model. + Args: + config: Path to `axolotl` config YAML file. + accelerate: Whether to use `accelerate` launcher. + gradio: Whether to use Gradio browser interface or command line for inference. + kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` + config options. + """ if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] if config: base_cmd.append(config) + if gradio: + base_cmd.append("--gradio") cmd = build_command(base_cmd, kwargs) subprocess.run(cmd, check=True) # nosec B603 else: from axolotl.cli.inference import do_cli - do_cli(config=config, **kwargs) - - -@cli.command() -@click.argument("config", type=click.Path(exists=True, path_type=str)) -@click.option( - "--accelerate/--no-accelerate", - default=False, - help="Use accelerate launch for multi-GPU operations", -) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing model weights to shard", -) -@click.option( - "--save-dir", - type=click.Path(path_type=str), - help="Directory to save sharded weights", -) -@add_options_from_dataclass(TrainerCliArgs) -@add_options_from_config(AxolotlInputConfig) -def shard(config: str, accelerate: bool, **kwargs): - """Shard model weights.""" - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"] - if config: - base_cmd.append(config) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 - else: - from axolotl.cli.shard import do_cli - - do_cli(config=config, **kwargs) + do_cli(config=config, gradio=gradio, **kwargs) @cli.command() @@ -180,20 +154,19 @@ def shard(config: str, accelerate: bool, **kwargs): default=True, help="Use accelerate launch for weight merging", ) -@click.option( - "--model-dir", - type=click.Path(exists=True, path_type=str), - help="Directory containing sharded weights", -) -@click.option( - "--save-path", type=click.Path(path_type=str), help="Path to save merged weights" -) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) -def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): - """Merge sharded 
FSDP model weights."""
-    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+@filter_none_kwargs
+def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
+    """
+    Merge sharded FSDP model weights.
+
+    Args:
+        config: Path to `axolotl` config YAML file.
+        accelerate: Whether to use `accelerate` launcher.
+        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
+            config options.
+    """
     if accelerate:
         base_cmd = [
             "accelerate",
@@ -213,28 +186,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
-@click.option(
-    "--lora-model-dir",
-    type=click.Path(exists=True, path_type=str),
-    help="Directory containing the LoRA model to merge",
-)
-@click.option(
-    "--output-dir",
-    type=click.Path(path_type=str),
-    help="Directory to save the merged model",
-)
-def merge_lora(
-    config: str,
-    lora_model_dir: Optional[str] = None,
-    output_dir: Optional[str] = None,
-):
-    """Merge a trained LoRA into a base model"""
-    kwargs = {}
-    if lora_model_dir:
-        kwargs["lora_model_dir"] = lora_model_dir
-    if output_dir:
-        kwargs["output_dir"] = output_dir
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+@filter_none_kwargs
+def merge_lora(config: str, **kwargs) -> None:
+    """
+    Merge trained LoRA adapters into a base model.
+
+    Args:
+        config: Path to `axolotl` config YAML file.
+        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
+            config options.
+    """
     from axolotl.cli.merge_lora import do_cli

     do_cli(config=config, **kwargs)
@@ -243,13 +207,17 @@ def merge_lora(
 @cli.command()
 @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
 @click.option("--dest", help="Destination directory")
-def fetch(directory: str, dest: Optional[str]):
+def fetch(directory: str, dest: Optional[str]) -> None:
     """
     Fetch example configs or other resources.

     Available directories:
     - examples: Example configuration files
     - deepspeed_configs: DeepSpeed configuration files
+
+    Args:
+        directory: One of `examples`, `deepspeed_configs`.
+        dest: Optional destination directory.
     """
     fetch_from_github(f"{directory}/", dest)
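The `merge_lora` command above delegates to `axolotl.cli.merge_lora.do_cli`, whose core is PEFT's `merge_and_unload` (shown in the next file). For orientation, the same merge can be done directly with PEFT; the paths and adapter layout below are illustrative:

```python
# Illustrative stand-alone LoRA merge with PEFT; assumes ./outputs/lora-run
# holds adapter weights saved from a training run.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained("./outputs/lora-run")  # base model + adapters
merged = model.merge_and_unload(progressbar=True)  # fold LoRA deltas into the base weights
merged.save_pretrained("./outputs/lora-run/merged", safe_serialization=True)

tokenizer = AutoTokenizer.from_pretrained("./outputs/lora-run")
tokenizer.save_pretrained("./outputs/lora-run/merged")
```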
diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 8c321bc48..595eb3eab 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py
@@ -1,6 +1,6 @@
-"""
-CLI to run merge a trained LoRA into a base model
-"""
+"""CLI to merge a trained LoRA into a base model."""
+
+import logging
 from pathlib import Path
 from typing import Union
@@ -8,14 +8,58 @@
 import fire
 import transformers
 from dotenv import load_dotenv
-from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.cli.art import print_axolotl_text_art
+from axolotl.cli.config import load_cfg
+from axolotl.cli.utils import load_model_and_tokenizer
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger(__name__)
-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
-    # pylint: disable=duplicate-code
+def do_merge_lora(*, cfg: DictDefault) -> None:
+    """
+    Calls `peft`'s `merge_and_unload` on the model given in the `axolotl` config
+    along with the LoRA adapters to combine them into a single base model.
+
+    Args:
+        cfg: Dictionary mapping `axolotl` config keys to values.
+    """
     print_axolotl_text_art()
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
+
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg)
+    safe_serialization = cfg.save_safetensors is True
+
+    LOG.info("Running merge of LoRA with base model...")
+    model = model.merge_and_unload(progressbar=True)
+    model.to(dtype=cfg.torch_dtype)
+    model.generation_config.do_sample = True
+
+    if cfg.local_rank == 0:
+        LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
+        model.save_pretrained(
+            str(Path(cfg.output_dir) / "merged"),
+            safe_serialization=safe_serialization,
+            progressbar=True,
+        )
+        tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
+
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+    """
+    Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
+    config values will be overwritten to allow the LoRA merge logic to work as expected
+    (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.).
+
+    Args:
+        config: Path to `axolotl` config YAML file.
+        kwargs: Additional keyword arguments to override config file values.
+
+    Raises:
+        ValueError: If target directory for LoRA merged model does not exist.
+    """
+    # pylint: disable=duplicate-code
+    parser = transformers.HfArgumentParser(TrainerCliArgs)
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
@@ -46,7 +90,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     parsed_cfg.fsdp = None
     parsed_cfg.fsdp_config = None
-    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    do_merge_lora(cfg=parsed_cfg)
 if __name__ == "__main__":
diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 6be9af1f7..d4b36d92c 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py
@@ -1,6 +1,5 @@
-"""
-This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
-"""
+"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
+
 import json
 import logging
 import os
@@ -25,16 +24,15 @@
 from huggingface_hub import split_torch_state_dict_into_shards
 from safetensors.torch import save_file as safe_save_file
 from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
-from axolotl.cli import load_cfg, print_axolotl_text_art
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.cli.art import print_axolotl_text_art
+from axolotl.cli.config import load_cfg
-LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
+LOG = logging.getLogger(__name__)
 class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
-    """
-    A custom planner to cast tensors to bfloat16 on the fly during loading.
-    """
+    """A custom planner to cast tensors to bfloat16 on the fly during loading."""
     def commit_tensor(self, read_item, tensor):  # pylint: disable=unused-argument
         tensor.copy_(tensor.to(torch.bfloat16))
@@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights(
     save_path: str,
     safe_serialization: bool = False,
     max_shard_size: str = "5GB",
-):
+) -> Path:
     """
-    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
+    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`.
Will + save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. - Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`. + Args: + checkpoint_dir: Directory where distributed checkpoint is saved. + save_path: Path to save model to. + safe_serialization: Whether to save in safetensors format. + max_shard_size: Max size of model shards to save. + + Returns: + Path where model is saved. """ state_dict: Dict = {} @@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights( state_dict_split = split_torch_state_dict_into_shards( state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size ) + # Save index if sharded index = None if state_dict_split.is_sharded: @@ -135,6 +142,9 @@ def merge_fsdp_weights( Whether to save the merged weights with safetensors (recommended). remove_checkpoint_dir (`bool`, *optional*, defaults to `False`): Whether to remove the checkpoint directory after merging. + + Raises: + ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist. """ checkpoint_dir_ = Path(checkpoint_dir) from accelerate.state import PartialState @@ -178,18 +188,21 @@ def merge_fsdp_weights( def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + """ + Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ # pylint: disable=duplicate-code print_axolotl_text_art() - parser = transformers.HfArgumentParser((TrainerCliArgs)) + parser = transformers.HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) parsed_cli_args.merge_lora = True - - parsed_cfg = load_cfg( - config, - **kwargs, - ) + parsed_cfg = load_cfg(config, **kwargs) fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0" merge_fsdp_weights( diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index a1592aa78..760fe76fa 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run preprocessing of a dataset.""" + import logging import warnings from pathlib import Path @@ -13,34 +12,31 @@ from colorama import Fore from dotenv import load_dotenv from transformers import AutoModelForCausalLM -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import PreprocessCliArgs +from axolotl.cli.args import PreprocessCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.common.datasets import load_datasets, load_preference_datasets +from axolotl.utils.dict import DictDefault from axolotl.utils.trainer import disable_datasets_caching -LOG = logging.getLogger("axolotl.cli.preprocess") +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code +def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: + """ + Preprocesses dataset specified in axolotl config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Preprocessing-specific CLI arguments. 
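`_distributed_checkpoint_to_merged_weights` above leans on `huggingface_hub.split_torch_state_dict_into_shards` to decide shard boundaries and emit an index. A toy sketch showing the helper's behavior in isolation, with sizes shrunk so sharding actually triggers:

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Toy state dict: 8 tensors of ~4MB each (fp32)
state_dict = {f"layer.{i}.weight": torch.zeros(1024, 1024) for i in range(8)}

split = split_torch_state_dict_into_shards(
    state_dict,
    filename_pattern="model{suffix}.safetensors",
    max_shard_size="5MB",  # tiny cap so the toy dict is forced into multiple shards
)
print(split.is_sharded)                   # True
print(sorted(split.filename_to_tensors))  # e.g. ['model-00001-of-00008.safetensors', ...]
```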
+ """ print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parsed_cfg.is_preprocess = True check_accelerate_default_config() check_user_token() - parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - if not parsed_cfg.dataset_prepared_path: + if not cfg.dataset_prepared_path: msg = ( Fore.RED + "preprocess CLI called without dataset_prepared_path set, " @@ -48,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + Fore.RESET ) LOG.warning(msg) - parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH + cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH with disable_datasets_caching(): - if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": - load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if cfg.rl: + load_preference_datasets(cfg=cfg, cli_args=cli_args) else: - load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + load_datasets(cfg=cfg, cli_args=cli_args) - if parsed_cli_args.download: - model_name = parsed_cfg.base_model + if cli_args.download: + model_name = cfg.base_model with warnings.catch_warnings(): # there are a bunch of useless UserWarnings about # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model" @@ -74,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.info( Fore.GREEN - + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`" + Fore.RESET ) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_preprocess`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. 
+ """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parsed_cfg.is_preprocess = True + parser = transformers.HfArgumentParser(PreprocessCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_preprocess(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py deleted file mode 100644 index 196c0e99a..000000000 --- a/src/axolotl/cli/shard.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -CLI to shard a trained model into 10GiB chunks -""" -import logging -from pathlib import Path -from typing import Union - -import fire -import transformers -from dotenv import load_dotenv - -from axolotl.cli import load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer -from axolotl.utils.dict import DictDefault - -LOG = logging.getLogger("axolotl.scripts") - - -def shard( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - safe_serialization = cfg.save_safetensors is True - LOG.debug("Re-saving model w/ sharding") - model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) - - -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - print_axolotl_text_art() - parsed_cfg = load_cfg(config, **kwargs) - parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - parsed_cli_args.shard = True - - shard(cfg=parsed_cfg, cli_args=parsed_cli_args) - - -if __name__ == "__main__": - load_dotenv() - fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2a40e854e..9e3ae1cc3 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run training on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,42 +8,38 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train +from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.train") +LOG = logging.getLogger(__name__) -def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): - # pylint: disable=duplicate-code - parsed_cfg = load_cfg(config, **kwargs) - parser = HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( - return_remaining_strings=True - ) - return do_train(parsed_cfg, parsed_cli_args) +def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: + """ + Trains a `transformers` model by first loading the dataset(s) specified in the + `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin + manager's `post_train_unload` once training completes. 
- -def do_train(cfg, cli_args) -> None: + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Training-specific CLI arguments. + """ print_axolotl_text_art() check_accelerate_default_config() check_user_token() - if cfg.rl: # and cfg.rl != "orpo": - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + if cfg.rl: + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) plugin_manager = PluginManager.get_instance() del model @@ -53,6 +48,24 @@ def do_train(cfg, cli_args) -> None: plugin_manager.post_train_unload(cfg) +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: + """ + Parses `axolotl` config, CLI args, and calls `do_train`. + + Args: + config: Path to `axolotl` config YAML file. + kwargs: Additional keyword arguments to override config file values. + """ + # pylint: disable=duplicate-code + parsed_cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(TrainerCliArgs) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + + do_train(parsed_cfg, parsed_cli_args) + + if __name__ == "__main__": load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index 85d241b5d..addfa0ab9 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -1,32 +1,85 @@ -"""Utility methods for axoltl CLI.""" +"""Utility methods for axolotl CLI.""" + import concurrent.futures import dataclasses import hashlib import json import logging +import typing +from functools import wraps from pathlib import Path from types import NoneType -from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin +from typing import Any, Callable, Type, Union, get_args, get_origin import click import requests from pydantic import BaseModel +from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast -LOG = logging.getLogger("axolotl.cli.utils") +from axolotl.logging_config import configure_logging +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + +configure_logging() +LOG = logging.getLogger(__name__) -def add_options_from_dataclass(config_class: Type[Any]): - """Create Click options from the fields of a dataclass.""" +def strip_optional_type(field_type: type | typing._SpecialForm | None): + """ + Extracts the non-`None` type from an `Optional` / `Union` type. - def decorator(function): + Args: + field_type: Type of field for Axolotl CLI command. + + Returns: + If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise + returns the input type unchanged. + """ + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + return field_type + + +def filter_none_kwargs(func: Callable) -> Callable: + """ + Wraps function to remove `None`-valued `kwargs`. + + Args: + func: Function to wrap. + + Returns: + Wrapped function. 
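A self-contained illustration of the two helpers documented above. The local `strip_optional` mirrors `strip_optional_type` but uses an identity check (`t is not type(None)`) rather than the patch's `isinstance` test; `filter_none_kwargs` is used as imported from this patch:

```python
from typing import Optional, Union, get_args, get_origin

from axolotl.cli.utils import filter_none_kwargs


def strip_optional(tp):
    """Return T for Optional[T] / Union[T, None]; pass other types through."""
    if get_origin(tp) is Union and type(None) in get_args(tp):
        return next(t for t in get_args(tp) if t is not type(None))
    return tp


assert strip_optional(Optional[int]) is int
assert strip_optional(Union[str, None]) is str
assert strip_optional(bool) is bool


@filter_none_kwargs
def launch(config: str, seed: int = 42, **overrides):
    return config, seed, overrides


# seed=None is dropped by the wrapper, so the default of 42 survives.
assert launch("cfg.yml", seed=None, micro_batch_size=2) == (
    "cfg.yml",
    42,
    {"micro_batch_size": 2},
)
```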
+ """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + """Filters out `None`-valued `kwargs`.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return func(*args, **filtered_kwargs) + + return wrapper + + +def add_options_from_dataclass(config_class: Type[Any]) -> Callable: + """ + Create Click options from the fields of a dataclass. + + Args: + config_class: Dataclass with fields to parse from the CLI. + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: # Process dataclass fields in reverse order for correct option ordering for field in reversed(dataclasses.fields(config_class)): - field_type = field.type + field_type = strip_optional_type(field.type) - if get_origin(field_type) is Union and type(None) in get_args(field_type): - field_type = next( - t for t in get_args(field_type) if not isinstance(t, NoneType) - ) if field_type == bool: field_name = field.name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" @@ -43,18 +96,29 @@ def add_options_from_dataclass(config_class: Type[Any]): default=field.default, help=field.metadata.get("description"), )(function) + return function return decorator -def add_options_from_config(config_class: Type[BaseModel]): - """Create Click options from the fields of a Pydantic model.""" +def add_options_from_config(config_class: Type[BaseModel]) -> Callable: + """ + Create Click options from the fields of a Pydantic model. - def decorator(function): + Args: + config_class: PyDantic model with fields to parse from the CLI + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + field_type = strip_optional_type(field.annotation) + + if field_type == bool: field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( @@ -65,13 +129,23 @@ def add_options_from_config(config_class: Type[BaseModel]): function = click.option( option_name, default=None, help=field.description )(function) + return function return decorator -def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: - """Build command list from base command and options.""" +def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: + """ + Build command list from base command and options. + + Args: + base_cmd: Command without options. + options: Options to parse and append to base command. + + Returns: + List of strings giving shell command. + """ cmd = base_cmd.copy() for key, value in options.items(): @@ -91,18 +165,18 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: def download_file( file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str -) -> Tuple[str, str]: +) -> tuple[str, str]: """ Download a single file and return its processing status. Args: - file_info: Tuple of (file_path, remote_sha) - raw_base_url: Base URL for raw GitHub content - dest_path: Local destination directory - dir_prefix: Directory prefix to filter files + file_info: Tuple of (file_path, remote_sha). + raw_base_url: Base URL for raw GitHub content. + dest_path: Local destination directory. + dir_prefix: Directory prefix to filter files. 
Returns: - Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged' + Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. """ file_path, remote_sha = file_info raw_url = f"{raw_base_url}/{file_path}" @@ -144,16 +218,17 @@ def download_file( def fetch_from_github( - dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5 + dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 ) -> None: """ Sync files from a specific directory in the GitHub repository. Only downloads files that don't exist locally or have changed. Args: - dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/') - dest_dir: Local destination directory - max_workers: Maximum number of concurrent downloads + dir_prefix: Directory prefix to filter files (e.g., 'examples/', + 'deepspeed_configs/'). + dest_dir: Local destination directory. + max_workers: Maximum number of concurrent downloads. """ api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" @@ -178,7 +253,7 @@ def fetch_from_github( dest_path = Path(dest_dir) if dest_dir else default_dest # Keep track of processed files for summary - files_processed: Dict[str, List[str]] = { + files_processed: dict[str, list[str]] = { "new": [], "updated": [], "unchanged": [], @@ -215,3 +290,28 @@ def fetch_from_github( LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") if files_processed["error"]: LOG.info(f"Failed files: {len(files_processed['error'])}") + + +def load_model_and_tokenizer( + *, + cfg: DictDefault, + inference: bool = False, +) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]: + """ + Helper function for loading a model and tokenizer specified in the given `axolotl` + config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + `transformers` model and tokenizer. + """ + LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + LOG.info("loading model...") + model, _ = load_model(cfg, tokenizer, inference=inference) + + return model, tokenizer diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py deleted file mode 100644 index 02ad9201b..000000000 --- a/src/axolotl/common/cli.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -shared module for cli specific things -""" - -import logging -from dataclasses import dataclass, field -from typing import Optional - -import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 -from axolotl.logging_config import configure_logging -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer - -configure_logging() -LOG = logging.getLogger("axolotl.common.cli") - - -@dataclass -class PreprocessCliArgs: - """ - dataclass representing arguments for preprocessing only - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=1) - prompter: Optional[str] = field(default=None) - download: Optional[bool] = field(default=True) - - -@dataclass -class TrainerCliArgs: - """ - dataclass representing the various non-training arguments - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - inference: bool = field(default=False) - merge_lora: bool = field(default=False) - prompter: Optional[str] = field(default=None) - shard: bool = field(default=False) - - -@dataclass -class EvaluateCliArgs: - """ - dataclass representing the various evaluation arguments - """ - - debug: bool = field(default=False) - debug_text_only: bool = field(default=False) - debug_num_examples: int = field(default=0) - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - - LOG.info("loading model and (optionally) peft_config...") - inference = getattr(cli_args, "inference", False) - model, _ = load_model(cfg, tokenizer, inference=inference) - - return model, tokenizer diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py new file mode 100644 index 000000000..d07add29b --- /dev/null +++ b/src/axolotl/common/datasets.py @@ -0,0 +1,140 @@ +"""Dataset loading utilities.""" + +import logging +import math +import random +from dataclasses import dataclass +from typing import Optional, Union + +from datasets import Dataset + +import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 +from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs +from axolotl.utils.data import prepare_dataset +from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_processor, load_tokenizer +from axolotl.utils.tokenization import check_dataset_labels + +LOG = logging.getLogger(__name__) + + +@dataclass +class TrainDatasetMeta: + """Dataclass with fields for training and validation datasets and metadata.""" + + train_dataset: Dataset + eval_dataset: Optional[Dataset] = None + total_num_steps: Optional[int] = None + + +def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset: + """ + Randomly sample `num_samples` samples from `dataset`. + + Args: + dataset: Dataset. 
+ num_samples: Number of samples to return. + + Returns: + Random sample (with replacement) of examples in `dataset`. + """ + return dataset.select( + [random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec + ) + + +def load_datasets( + *, + cfg: DictDefault, + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], +) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets, calling + `axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. + """ + tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None + + train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + cfg, + tokenizer, + processor=processor, + ) + + if ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 + ): + LOG.info("check_dataset_labels...") + + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + check_dataset_labels( + train_samples, + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + ) + + LOG.info("printing prompters...") + for prompter in prompters: + LOG.info(prompter) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + +def load_preference_datasets( + *, + cfg: DictDefault, + cli_args: Union[PreprocessCliArgs, TrainerCliArgs], +) -> TrainDatasetMeta: + """ + Loads one or more training or evaluation datasets for DPO training, calling + `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug + information. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + cli_args: Command-specific CLI arguments. + + Returns: + Dataclass with fields for training and evaluation datasets and the computed + `total_num_steps`. 
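Two details worth flagging here. First, `sample_dataset` draws with replacement via `random.randrange(0, len(dataset) - 1)`, whose exclusive upper bound means the last example can never be selected; `random.randrange(len(dataset))` would cover the full index range. Second, `load_preference_datasets` derives its step count arithmetically rather than from a dataloader; a worked instance of that formula:

```python
import math

# total_num_steps = ceil(len(train_dataset) * num_epochs / batch_size)
num_examples, num_epochs, batch_size = 10_000, 3, 16
total_num_steps = int(math.ceil(num_examples * num_epochs / batch_size))
assert total_num_steps == 1875  # 30_000 / 16 = 1875 exactly
```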
+ """ + train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + if cli_args.debug or cfg.debug: + LOG.info("check_dataset_labels...") + + tokenizer = load_tokenizer(cfg) + train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples) + check_dataset_labels( + train_samples, + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + rl_mode=True, + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index acf15e3fc..8d9ddc6ab 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -9,7 +9,6 @@ from typing import Dict, Optional import torch from accelerate.logging import get_logger -from axolotl.common.cli import TrainerCliArgs from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils import set_pytorch_cuda_alloc_conf @@ -62,16 +61,13 @@ def evaluate_dataset( return metrics -def evaluate( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta -) -> Dict[str, float]: +def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: """ Evaluate a model on training and validation datasets Args: - cfg: Configuration dictionary - cli_args: Command line arguments - dataset_meta: Dataset metadata containing training and evaluation datasets + cfg: Dictionary mapping `axolotl` config keys to values. + dataset_meta: Dataset metadata containing training and evaluation datasets. Returns: Tuple containing: @@ -102,9 +98,7 @@ def evaluate( # Load model LOG.debug("loading model for evaluation...") - model, _ = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, _ = load_model(cfg, tokenizer, processor=processor) # Set up trainer trainer = setup_trainer( diff --git a/src/axolotl/train.py b/src/axolotl/train.py index a74ecc2ec..b901c2a97 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -5,21 +5,19 @@ import os import signal import sys import weakref -from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Tuple, Union import torch import transformers.modelcard from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model -from datasets import Dataset from peft import PeftModel from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled -from axolotl.common.cli import TrainerCliArgs +from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) @@ -39,22 +37,11 @@ src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) configure_logging() -LOG = get_logger("axolotl.train") - - -@dataclass -class TrainDatasetMeta: - """ - dataclass to capture the dataset specific options for training - """ - - train_dataset: Dataset - eval_dataset: Optional[Dataset] = None - total_num_steps: Optional[int] = None +LOG = get_logger(__name__) def train( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta + *, cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> Tuple[Union[PeftModel, PreTrainedModel], 
PreTrainedTokenizer]: # Load tokenizer LOG.debug( @@ -93,9 +80,7 @@ def train( if cfg.adapter: msg += " and peft_config..." LOG.debug(msg) - model, peft_config = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) + model, peft_config = load_model(cfg, tokenizer, processor=processor) if model.generation_config is not None: model.generation_config.do_sample = True @@ -107,9 +92,7 @@ def train( model_ref = None # explicit setting to None else: # load the model again for model_ref/baseline - model_ref, _ = load_model( - cfg, tokenizer, inference=cli_args.inference, reference_model=True - ) + model_ref, _ = load_model(cfg, tokenizer, reference_model=True) safe_serialization = cfg.save_safetensors is True diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index aff047675..de373c06e 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -109,7 +109,9 @@ def prepare_dataset(cfg, tokenizer, processor=None): cfg.pretraining_dataset[0]["type"] or "pretrain", ) - iter_ds = load_dataset(path, streaming=True, split=split, name=name, data_files=data_files) + iter_ds = load_dataset( + path, streaming=True, split=split, name=name, data_files=data_files + ) if skip: LOG.info(f"Skipping {skip} samples from the dataset") iter_ds = iter_ds.skip(skip) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19..d360e29d6 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b029..f06f06717 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d25..b8effa3d2 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b76..8b5fec17f 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98..aac016760 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b77..18589a80d 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): assert mock.called assert mock.call_args.kwargs["config"] == str(config_path) assert result.exit_code == 0 - - -def 
test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path): - """Test merge_sharded_fsdp_weights command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] == str(model_dir) - assert result.exit_code == 0 - - -def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path): - """Test merge_sharded_fsdp_weights command with save_path option""" - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--save-path", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["save_path"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aa..e2dd3a6c3 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py deleted file mode 100644 index 505a2a737..000000000 --- a/tests/cli/test_cli_shard.py +++ /dev/null @@ -1,76 +0,0 @@ -"""pytest tests for axolotl CLI shard command.""" -# pylint: disable=duplicate-code -from unittest.mock import patch - -from axolotl.cli.main import cli - - -def test_shard_with_accelerate(cli_runner, config_path): - """Test shard command with accelerate""" - with patch("subprocess.run") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) - - assert mock.called - assert mock.call_args.args[0] == [ - "accelerate", - "launch", - "-m", - "axolotl.cli.shard", - str(config_path), - "--debug-num-examples", - "0", - ] - assert mock.call_args.kwargs == {"check": True} - assert result.exit_code == 0 - - -def test_shard_no_accelerate(cli_runner, config_path): - """Test shard command without accelerate""" - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"]) - - assert mock.called - assert result.exit_code == 0 - - -def test_shard_with_model_dir(cli_runner, config_path, tmp_path): - """Test shard command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - catch_exceptions=False, - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] == str(model_dir) - assert result.exit_code == 0 - - -def test_shard_with_save_dir(cli_runner, config_path): - with patch("axolotl.cli.shard.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "shard", - str(config_path), - "--no-accelerate", - "--save-dir", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert 
mock.call_args.kwargs["save_dir"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e94..533dd5c0e 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac72..ecb0025e4 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py index 6562af176..291a4a4ec 100644 --- a/tests/e2e/integrations/test_cut_cross_entropy.py +++ b/tests/e2e/integrations/test_cut_cross_entropy.py @@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils import get_pytorch_version from axolotl.utils.config import normalize_config, prepare_plugins @@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.parametrize( @@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration: major, minor, _ = get_pytorch_version() if (major, minor) < (2, 4): with pytest.raises(ImportError): - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) else: - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 9154bf9b8..1efe889e4 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -4,8 +4,8 @@ Simple end-to-end test for Liger integration from e2e.utils import require_torch_2_4_1 -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.dict import DictDefault @@ -60,7 +60,7 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @require_torch_2_4_1 @@ -105,5 +105,5 @@ class LigerIntegrationTestCase: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 
08b3bf0da..da27069ac 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,7 +65,7 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_cli_integrations.py b/tests/e2e/patched/test_cli_integrations.py index 6ca7c52ae..ce9396d5f 100644 --- a/tests/e2e/patched/test_cli_integrations.py +++ b/tests/e2e/patched/test_cli_integrations.py @@ -5,7 +5,7 @@ from pathlib import Path import yaml -from axolotl.cli import load_cfg +from axolotl.cli.config import load_cfg from axolotl.utils.dict import DictDefault diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 791d955b2..2bfd36d15 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -8,8 +8,8 @@ import os import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -80,7 +80,7 @@ class TestFAXentropyLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index 69516810f..62ee4f717 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -107,5 +107,5 @@ class TestFalconPatched(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + 
train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 23a0adfc0..e7ab510c9 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,5 +71,5 @@ class TestFusedLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index d0fdd918a..8d0ba6c2a 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -8,8 +8,8 @@ import unittest import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -69,7 +69,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -109,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index 634e544d2..bc18e3d81 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -9,8 +9,8 @@ import unittest import pytest from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -74,7 +74,7 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available") @@ -124,5 +124,5 @@ class TestLoraLlama(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) 
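The same two-line migration recurs across these test diffs; a condensed before/after summary, distilled from this patch rather than quoted from any one file:

```python
# Before this refactor:
#   from axolotl.cli import load_datasets, load_rl_datasets
#   from axolotl.common.cli import TrainerCliArgs
#   train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

# After:
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.train import train  # train() no longer accepts cli_args
```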
check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index e93863e09..c7fd0ecbc 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -108,5 +108,5 @@ class TestMistral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index f87c34fd1..156dac7e8 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -64,7 +64,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -102,7 +102,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( "MixtralFlashAttention2" in model.model.layers[0].self_attn.__class__.__name__ diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 170c37fd6..78b01be64 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -6,7 +6,6 @@ import unittest import transformers -from axolotl.common.cli import TrainerCliArgs from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) + model, _ = load_model(cfg, tokenizer, inference=False) assert ( "MixtralFlashAttention2" @@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase): } ) normalize_config(cfg) - cli_args = TrainerCliArgs() tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer, inference=cli_args.inference) + load_model(cfg, tokenizer, inference=False) assert 
( "torch.jit" diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index 852ac7bec..ce466460e 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -6,8 +6,8 @@ import logging import os import unittest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -67,7 +67,7 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) @with_temp_dir @@ -118,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 5639d2eae..f6a3e0109 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -9,8 +9,8 @@ import subprocess from transformers.utils import is_torch_bf16_gpu_available -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -71,7 +71,7 @@ class TestResumeLlama: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) resume_cfg = cfg | DictDefault( { @@ -81,7 +81,7 @@ class TestResumeLlama: normalize_config(resume_cfg) cli_args = TrainerCliArgs() - train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=resume_cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 492bc1c23..da5eaffb6 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -6,8 +6,8 @@ import os import pytest -from axolotl.cli import load_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -75,7 +75,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -125,7 +125,7 @@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( @@ -180,7 +180,7 
@@ class TestUnslothQLoRA: cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) check_tensorboard( diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index f8109373a..2d0baceee 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -9,8 +9,8 @@ from pathlib import Path import pytest -from axolotl.cli import load_rl_datasets -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_preference_datasets from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault @@ -65,9 +65,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -110,9 +110,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -155,9 +155,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip("kto_pair no longer supported in trl") @@ -200,9 +200,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -244,9 +244,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @with_temp_dir @@ -291,9 +291,9 @@ class TestDPOLlamaLora(unittest.TestCase): ) normalize_config(cfg) cli_args = TrainerCliArgs() - dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) @pytest.mark.skip(reason="Fix the implementation") @@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase): ) 
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 222d620ae..4261ccc26 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
@@ -104,7 +104,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

         check_tensorboard(
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 117de6635..ddcb66275 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -69,7 +69,7 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -122,7 +122,7 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -161,5 +161,5 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 4384bb61e..a94828490 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -7,8 +7,8 @@ import os

 from e2e.utils import check_model_output_exists

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -60,7 +60,7 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     def test_fix_untrained_tokens(self, temp_dir):
@@ -103,7 +103,7 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     def test_batch_flattening(self, temp_dir):
@@ -142,5 +142,5 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index d13b10659..68cd490be 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -62,5 +62,5 @@ class TestPretrainLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index 250cf418c..91f101e44 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -66,7 +66,7 @@ class TestLlamaVision(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -111,5 +111,5 @@ class TestLlamaVision(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index a7ead64a5..696c47aed 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -63,5 +63,5 @@ class TestLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index a1fc30862..4b4db3058 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -8,8 +8,8 @@ import unittest

 import pytest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -63,5 +63,5 @@ class TestMamba(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 2e79fec8d..a304e9b4a 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -8,8 +8,8 @@ import unittest

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -67,7 +67,7 @@ class TestMistral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -110,5 +110,5 @@ class TestMistral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 6792d05a6..6e06626f6 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -9,8 +9,8 @@ import unittest
 import torch
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -73,7 +73,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -127,7 +127,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -184,7 +184,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -241,7 +241,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -285,5 +285,5 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index f1bbaafd5..453872538 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestCustomOptimizers(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -107,7 +107,7 @@ class TestCustomOptimizers(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -143,5 +143,5 @@ class TestCustomOptimizers(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index dd0af32f3..13244a215 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -8,8 +8,8 @@ import unittest

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -63,7 +63,7 @@ class TestPackedLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 7a08d0c6f..54f564d0e 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestPhi(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

     @with_temp_dir
@@ -114,5 +114,5 @@ class TestPhi(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_relora_llama.py b/tests/e2e/test_relora_llama.py
index fef6a3d30..6c785dc86 100644
--- a/tests/e2e/test_relora_llama.py
+++ b/tests/e2e/test_relora_llama.py
@@ -7,8 +7,8 @@ import os
 import unittest
 from pathlib import Path

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -77,7 +77,7 @@ class TestReLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
         assert (
             Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
diff --git a/tests/e2e/test_reward_model_llama.py b/tests/e2e/test_reward_model_llama.py
index c4cb705ea..4cd8602f3 100644
--- a/tests/e2e/test_reward_model_llama.py
+++ b/tests/e2e/test_reward_model_llama.py
@@ -6,8 +6,8 @@ import logging
 import os
 import unittest

-from axolotl.cli import load_datasets
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -69,5 +69,5 @@ class TestRewardModelLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)