From 324c533adb9c4d873f3a546244c85e6458ed7654 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 7 Jan 2025 17:59:59 +0000 Subject: [PATCH] cleanup and (partial) docs --- scripts/finetune.py | 15 +- src/axolotl/cli/__init__.py | 535 +----------------- src/axolotl/cli/art.py | 50 ++ src/axolotl/cli/checks.py | 41 ++ src/axolotl/cli/config.py | 3 +- src/axolotl/cli/datasets.py | 2 +- src/axolotl/cli/evaluate.py | 19 +- src/axolotl/cli/inference.py | 5 +- src/axolotl/cli/main.py | 3 +- src/axolotl/cli/merge_lora.py | 7 +- src/axolotl/cli/merge_sharded_fsdp_weights.py | 9 +- src/axolotl/cli/preprocess.py | 14 +- src/axolotl/cli/shard.py | 9 +- src/axolotl/cli/train.py | 14 +- src/axolotl/cli/utils.py | 30 +- 15 files changed, 160 insertions(+), 596 deletions(-) create mode 100644 src/axolotl/cli/art.py create mode 100644 src/axolotl/cli/checks.py diff --git a/scripts/finetune.py b/scripts/finetune.py index d5bbcaf8f..73f082f21 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -5,15 +5,12 @@ from pathlib import Path import fire import transformers -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - do_inference, - do_merge_lora, - load_cfg, - load_datasets, - print_axolotl_text_art, -) +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.cli.datasets import load_datasets +from axolotl.cli.inference import do_inference +from axolotl.cli.merge_lora import do_merge_lora from axolotl.cli.shard import shard from axolotl.common.cli import TrainerCliArgs from axolotl.train import train diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 26d25258d..b20e4f085 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -1,536 +1,5 @@ -"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" -import logging +"""Axolotl CLI module initialization.""" + import os -import sys -from pathlib import Path - -# add src to the pythonpath so we don't need to pip install this -from accelerate.commands.config import config_args -from art import text2art -from huggingface_hub import HfApi -from huggingface_hub.utils import LocalTokenNotFoundError -from transformers.utils.import_utils import _is_package_available - -from axolotl.logging_config import configure_logging -from axolotl.utils.distributed import is_main_process - -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - -configure_logging() -LOG = logging.getLogger("axolotl.scripts") os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - -AXOLOTL_LOGO = """ - #@@ #@@ @@# @@# - @@ @@ @@ @@ =@@# @@ #@ =@@#. - @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ - #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ - @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ - @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ - @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ - =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ - @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@ - =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ - @@@@ @@@@@@@@@@@@@@@@ -""" - - -def print_legacy_axolotl_text_art(suffix=None): - font = "nancyj" - ascii_text = " axolotl" - if suffix: - ascii_text += f" x {suffix}" - ascii_art = text2art(ascii_text, font=font) - - if is_main_process(): - print(ascii_art) - - print_dep_versions() - - -def print_axolotl_text_art( - **kwargs, # pylint: disable=unused-argument -): - if is_main_process(): - print(AXOLOTL_LOGO) - - -def print_dep_versions(): - packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] - max_len = max(len(pkg) for pkg in packages) - if is_main_process(): - print("*" * 40) - print("**** Axolotl Dependency Versions *****") - for pkg in packages: - pkg_version = _is_package_available(pkg, return_version=True) - print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}") - print("*" * 40) - - -<<<<<<< HEAD -def check_remote_config(config: Union[str, Path]): - # Check if the config is a valid HTTPS URL to a .yml or .yaml file - if not (isinstance(config, str) and config.startswith("https://")): - return config # Return the original value if it's not a valid URL - - filename = os.path.basename(urlparse(config).path) - temp_dir = tempfile.mkdtemp() - - try: - response = requests.get(config, timeout=30) - response.raise_for_status() # Check for HTTP errors - - content = response.content - try: - # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML - json.loads(content) - # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link - LOG.warning( - f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." - ) - except json.JSONDecodeError: - # If it's not valid JSON, verify it's valid YAML - try: - yaml.safe_load(content) - except yaml.YAMLError as err: - raise ValueError( - f"Failed to parse the content at {config} as YAML: {err}" - ) from err - - # Write the content to a file if it's valid YAML (or JSON treated as YAML) - output_path = Path(temp_dir) / filename - with open(output_path, "wb") as file: - file.write(content) - LOG.info( - f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" - ) - return output_path - - except requests.RequestException as err: - # This catches all requests-related exceptions including HTTPError - raise RuntimeError(f"Failed to download {config}: {err}") from err - except Exception as err: - # Catch-all for any other exceptions - raise err - - -def get_multi_line_input() -> Optional[str]: - print("Give me an instruction (Ctrl + D to submit): ") - instruction = "" - for line in sys.stdin: - instruction += line # pylint: disable=consider-using-join - # instruction = pathlib.Path("/proc/self/fd/0").read_text() - return instruction - - -def do_merge_lora( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - safe_serialization = cfg.save_safetensors is True - - LOG.info("running merge of LoRA with base model") - model = model.merge_and_unload(progressbar=True) - try: - model.to(dtype=cfg.torch_dtype) - except RuntimeError: - pass - model.generation_config.do_sample = True - - if cfg.local_rank == 0: - LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") - model.save_pretrained( - str(Path(cfg.output_dir) / "merged"), - safe_serialization=safe_serialization, - progressbar=True, - ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) - - -def do_inference( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - - prompter_module = None - chat_template_str = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template) - elif cfg.datasets[0].type == "chat_template": - chat_template_str = get_chat_template_from_config( - cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer - ) - - model = model.to(cfg.device, dtype=cfg.torch_dtype) - - while True: - print("=" * 80) - # support for multiline inputs - instruction = get_multi_line_input() - if not instruction: - return - - if prompter_module: - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - - if chat_template_str: - batch = tokenizer.apply_chat_template( - [ - { - "role": "user", - "content": prompt, - } - ], - return_tensors="pt", - add_special_tokens=True, - add_generation_prompt=True, - chat_template=chat_template_str, - tokenize=True, - return_dict=True, - ) - else: - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - print("=" * 40) - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=1024, - temperature=0.9, - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextStreamer(tokenizer) - generated = model.generate( - inputs=batch["input_ids"].to(cfg.device), - generation_config=generation_config, - streamer=streamer, - ) - print("=" * 40) - print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) - - -def do_inference_gradio( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -): - import gradio as gr - - model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) - prompter = cli_args.prompter - - prompter_module = None - chat_template_str = None - if prompter: - prompter_module = getattr( - importlib.import_module("axolotl.prompters"), prompter - ) - elif cfg.chat_template: - chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) - - model = model.to(cfg.device, dtype=cfg.torch_dtype) - - def generate(instruction): - if not instruction: - return - if prompter_module: - # pylint: disable=stop-iteration-return - prompt: str = next( - prompter_module().build_prompt(instruction=instruction.strip("\n")) - ) - else: - prompt = instruction.strip() - - if chat_template_str: - batch = tokenizer.apply_chat_template( - [ - { - "role": "user", - "content": prompt, - } - ], - return_tensors="pt", - add_special_tokens=True, - add_generation_prompt=True, - chat_template=chat_template_str, - tokenize=True, - return_dict=True, - ) - else: - batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - - model.eval() - with torch.no_grad(): - generation_config = GenerationConfig( - repetition_penalty=1.1, - max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), - temperature=cfg.get("gradio_temperature", 0.9), - top_p=0.95, - top_k=40, - bos_token_id=tokenizer.bos_token_id, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - do_sample=True, - use_cache=True, - return_dict_in_generate=True, - output_attentions=False, - output_hidden_states=False, - output_scores=False, - ) - streamer = TextIteratorStreamer(tokenizer) - generation_kwargs = { - "inputs": batch["input_ids"].to(cfg.device), - "attention_mask": batch["attention_mask"].to(cfg.device), - "generation_config": generation_config, - "streamer": streamer, - } - - thread = Thread(target=model.generate, kwargs=generation_kwargs) - thread.start() - - all_text = "" - - for new_text in streamer: - all_text += new_text - yield all_text - - demo = gr.Interface( - fn=generate, - inputs="textbox", - outputs="text", - title=cfg.get("gradio_title", "Axolotl Gradio Interface"), - ) - - demo.queue().launch( - show_api=False, - share=cfg.get("gradio_share", True), - server_name=cfg.get("gradio_server_name", "127.0.0.1"), - server_port=cfg.get("gradio_server_port", None), - ) - - -def choose_config(path: Path): - yaml_files = list(path.glob("*.yml")) - - if not yaml_files: - raise ValueError( - "No YAML config files found in the specified directory. Are you using a .yml extension?" - ) - - if len(yaml_files) == 1: - print(f"Using default YAML file '{yaml_files[0]}'") - return str(yaml_files[0]) - - print("Choose a YAML file:") - for idx, file in enumerate(yaml_files): - print(f"{idx + 1}. {file}") - - chosen_file = None - while chosen_file is None: - try: - choice = int(input("Enter the number of your choice: ")) - if 1 <= choice <= len(yaml_files): - chosen_file = str(yaml_files[choice - 1]) - else: - print("Invalid choice. Please choose a number from the list.") - except ValueError: - print("Invalid input. Please enter a number.") - - return chosen_file - - -def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool: - return not any(el in list2 for el in list1) - - -def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): - config = check_remote_config(config) - if Path(config).is_dir(): - config = choose_config(Path(config)) - - # load the config from the yaml file - with open(config, encoding="utf-8") as file: - cfg: DictDefault = DictDefault(yaml.safe_load(file)) - # if there are any options passed in the cli, if it is something that seems valid from the yaml, - # then overwrite the value - cfg_keys = cfg.keys() - for k, _ in kwargs.items(): - # if not strict, allow writing to cfg even if it's not in the yml already - if k in cfg_keys or not cfg.strict: - # handle booleans - if isinstance(cfg[k], bool): - cfg[k] = bool(kwargs[k]) - else: - cfg[k] = kwargs[k] - - cfg.axolotl_config_path = config - - try: - device_props = torch.cuda.get_device_properties("cuda") - gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) - except: # pylint: disable=bare-except # noqa: E722 - gpu_version = None - - prepare_plugins(cfg) - - cfg = validate_config( - cfg, - capabilities={ - "bf16": is_torch_bf16_gpu_available(), - "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), - "compute_capability": gpu_version, - }, - env_capabilities={ - "torch_version": str(torch.__version__).split("+", maxsplit=1)[0], - }, - ) - - prepare_optim_env(cfg) - - prepare_opinionated_env(cfg) - - normalize_config(cfg) - - normalize_cfg_datasets(cfg) - - setup_wandb_env_vars(cfg) - - setup_mlflow_env_vars(cfg) - - setup_comet_env_vars(cfg) - - return cfg - - -def load_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, -) -> TrainDatasetMeta: - tokenizer = load_tokenizer(cfg) - processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - - train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( - cfg, - tokenizer, - processor=processor, - ) - - if ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ): - LOG.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - ) - - LOG.info("printing prompters...") - for prompter in prompters: - LOG.info(prompter) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -def load_rl_datasets( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, # pylint: disable=unused-argument -) -> TrainDatasetMeta: - train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - - if cli_args.debug or cfg.debug: - LOG.info("check_dataset_labels...") - - tokenizer = load_tokenizer(cfg) - check_dataset_labels( - train_dataset.select( - [ - random.randrange(0, len(train_dataset) - 1) # nosec - for _ in range(cli_args.debug_num_examples) - ] - ), - tokenizer, - num_examples=cli_args.debug_num_examples, - text_only=cli_args.debug_text_only, - rl_mode=True, - ) - - return TrainDatasetMeta( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - total_num_steps=total_num_steps, - ) - - -======= ->>>>>>> 73d65961 (CLI init refactor) -def check_accelerate_default_config(): - if Path(config_args.default_yaml_config_file).exists(): - LOG.warning( - f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" - ) - - -def check_user_token(): - # Skip check if HF_HUB_OFFLINE is set to True - if os.getenv("HF_HUB_OFFLINE") == "1": - LOG.info( - "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." - ) - return True - - # Verify if token is valid - api = HfApi() - try: - user_info = api.whoami() - return bool(user_info) - except LocalTokenNotFoundError: - LOG.warning( - "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." - ) - return False diff --git a/src/axolotl/cli/art.py b/src/axolotl/cli/art.py new file mode 100644 index 000000000..f8331cff5 --- /dev/null +++ b/src/axolotl/cli/art.py @@ -0,0 +1,50 @@ +"""Axolotl ASCII logo utils.""" + +from art import text2art +from transformers.utils.import_utils import _is_package_available + +from axolotl.utils.distributed import is_main_process + +AXOLOTL_LOGO = """ + #@@ #@@ @@# @@# + @@ @@ @@ @@ =@@# @@ #@ =@@#. + @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ + #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ + @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ + @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ + =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@ + =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ + @@@@ @@@@@@@@@@@@@@@@ +""" + + +def print_dep_versions(): + packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] + max_len = max(len(pkg) for pkg in packages) + if is_main_process(): + print("*" * 40) + print("**** Axolotl Dependency Versions *****") + for pkg in packages: + pkg_version = _is_package_available(pkg, return_version=True) + print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}") + print("*" * 40) + + +def print_legacy_axolotl_text_art(suffix=None): + font = "nancyj" + ascii_text = " axolotl" + if suffix: + ascii_text += f" x {suffix}" + ascii_art = text2art(ascii_text, font=font) + + if is_main_process(): + print(ascii_art) + + print_dep_versions() + + +def print_axolotl_text_art(): + if is_main_process(): + print(AXOLOTL_LOGO) diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py new file mode 100644 index 000000000..1d834691e --- /dev/null +++ b/src/axolotl/cli/checks.py @@ -0,0 +1,41 @@ +"""Various checks for Axolotl CLI.""" + +import logging +import os +from pathlib import Path + +from accelerate.commands.config import config_args +from huggingface_hub import HfApi +from huggingface_hub.utils import LocalTokenNotFoundError + +from axolotl.logging_config import configure_logging + +configure_logging() +LOG = logging.getLogger(__name__) + + +def check_accelerate_default_config(): + if Path(config_args.default_yaml_config_file).exists(): + LOG.warning( + f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" + ) + + +def check_user_token(): + # Skip check if HF_HUB_OFFLINE is set to True + if os.getenv("HF_HUB_OFFLINE") == "1": + LOG.info( + "Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used." + ) + return True + + # Verify if token is valid + api = HfApi() + try: + user_info = api.whoami() + return bool(user_info) + except LocalTokenNotFoundError: + LOG.warning( + "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." + ) + return False diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 9ae830808..44f628e6a 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -1,4 +1,5 @@ """Configuration loading and processing.""" + import json import logging import os @@ -24,7 +25,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars -LOG = logging.getLogger("axolotl.cli.config") +LOG = logging.getLogger(__name__) def check_remote_config(config: Union[str, Path]): diff --git a/src/axolotl/cli/datasets.py b/src/axolotl/cli/datasets.py index e7884eb90..e98321c0b 100644 --- a/src/axolotl/cli/datasets.py +++ b/src/axolotl/cli/datasets.py @@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels -LOG = logging.getLogger("axolotl.scripts") +LOG = logging.getLogger(__name__) def load_datasets( diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8e99d6f4b..17e2e8dc3 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run evaluation on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,18 +8,14 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, - print_axolotl_text_art, -) +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token +from axolotl.cli.config import load_cfg +from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.evaluate import evaluate -LOG = logging.getLogger("axolotl.cli.evaluate") +LOG = logging.getLogger(__name__) def do_evaluate(cfg, cli_args) -> None: diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 8fb05b9d4..be6ba8236 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,4 +1,5 @@ """CLI to run inference on a trained model.""" + import importlib import logging import sys @@ -12,7 +13,7 @@ import transformers from dotenv import load_dotenv from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer -from axolotl.cli import print_axolotl_text_art +from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.utils.chat_templates import ( @@ -21,7 +22,7 @@ from axolotl.utils.chat_templates import ( ) from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.cli.inference") +LOG = logging.getLogger(__name__) def get_multi_line_input() -> Optional[str]: diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43b..30fdb6ad7 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -1,5 +1,6 @@ -"""CLI definition for various axolotl commands.""" +"""Click CLI definitions for various axolotl commands.""" # pylint: disable=redefined-outer-name + import subprocess # nosec B404 from typing import Optional diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index bf30d9c33..42e315bfe 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -1,6 +1,5 @@ -""" -CLI to run merge a trained LoRA into a base model -""" +"""CLI to merge a trained LoRA into a base model.""" + import logging from pathlib import Path from typing import Union @@ -9,7 +8,7 @@ import fire import transformers from dotenv import load_dotenv -from axolotl.cli import print_axolotl_text_art +from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.utils.dict import DictDefault diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 152fa08f0..0ecf7e70f 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -1,6 +1,5 @@ -""" -This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint -""" +"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.""" + import json import logging import os @@ -25,11 +24,11 @@ from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner -from axolotl.cli import print_axolotl_text_art +from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs -LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") +LOG = logging.getLogger(__name__) class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index efb2a1809..182abfaa4 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run preprocessing of a dataset.""" + import logging import warnings from pathlib import Path @@ -13,18 +12,15 @@ from colorama import Fore from dotenv import load_dotenv from transformers import AutoModelForCausalLM -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - print_axolotl_text_art, -) +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.utils.trainer import disable_datasets_caching -LOG = logging.getLogger("axolotl.cli.preprocess") +LOG = logging.getLogger(__name__) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index 658442992..86533db7e 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -1,6 +1,5 @@ -""" -CLI to shard a trained model into 10GiB chunks -""" +"""CLI to shard a trained model into 10GiB chunks.""" + import logging from pathlib import Path from typing import Union @@ -9,12 +8,12 @@ import fire import transformers from dotenv import load_dotenv -from axolotl.cli import print_axolotl_text_art +from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.utils.dict import DictDefault -LOG = logging.getLogger("axolotl.scripts") +LOG = logging.getLogger(__name__) def shard( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 41f4b2d83..eb70cfaa1 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,6 +1,5 @@ -""" -CLI to run training on a model -""" +"""CLI to run training on a model.""" + import logging from pathlib import Path from typing import Union @@ -9,18 +8,15 @@ import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from axolotl.cli import ( - check_accelerate_default_config, - check_user_token, - print_axolotl_text_art, -) +from axolotl.cli.art import print_axolotl_text_art +from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import TrainerCliArgs from axolotl.integrations.base import PluginManager from axolotl.train import train -LOG = logging.getLogger("axolotl.cli.train") +LOG = logging.getLogger(__name__) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index 85d241b5d..ecd943f24 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -1,4 +1,5 @@ -"""Utility methods for axoltl CLI.""" +"""Utility methods for axolotl CLI.""" + import concurrent.futures import dataclasses import hashlib @@ -12,11 +13,16 @@ import click import requests from pydantic import BaseModel -LOG = logging.getLogger("axolotl.cli.utils") +LOG = logging.getLogger(__name__) def add_options_from_dataclass(config_class: Type[Any]): - """Create Click options from the fields of a dataclass.""" + """ + Create Click options from the fields of a dataclass. + + Args: + config_class: Dataclass with fields to parse from the CLI + """ def decorator(function): # Process dataclass fields in reverse order for correct option ordering @@ -49,7 +55,12 @@ def add_options_from_dataclass(config_class: Type[Any]): def add_options_from_config(config_class: Type[BaseModel]): - """Create Click options from the fields of a Pydantic model.""" + """ + Create Click options from the fields of a Pydantic model. + + Args: + config_class: PyDantic model with fields to parse from the CLI + """ def decorator(function): # Process model fields in reverse order for correct option ordering @@ -71,7 +82,16 @@ def add_options_from_config(config_class: Type[BaseModel]): def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: - """Build command list from base command and options.""" + """ + Build command list from base command and options. + + Args: + base_cmd: Command without options + options: Options to parse and append to base command + + Returns: + List of strings giving shell command + """ cmd = base_cmd.copy() for key, value in options.items():