From 541f9b39ff3668453f126112b4fa1a889b83e7ec Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 6 Dec 2024 11:57:53 -0500
Subject: [PATCH] CLI init refactor

---
 src/axolotl/cli/__init__.py                   |  35 ---
 src/axolotl/cli/config.py                     | 167 +++++++++++++
 src/axolotl/cli/datasets.py                   |  92 ++++++++
 src/axolotl/cli/inference.py                  | 220 +++++++++++++++++-
 src/axolotl/cli/merge_lora.py                 |  35 ++-
 src/axolotl/cli/merge_sharded_fsdp_weights.py |   3 +-
 src/axolotl/cli/preprocess.py                 |   5 +-
 src/axolotl/cli/shard.py                      |   3 +-
 src/axolotl/cli/train.py                      |   5 +-
 9 files changed, 510 insertions(+), 55 deletions(-)
 create mode 100644 src/axolotl/cli/config.py
 create mode 100644 src/axolotl/cli/datasets.py

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index d07b10ce3..26d25258d 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -1,53 +1,18 @@
 """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
-
-import importlib
-import json
 import logging
-import math
 import os
-import random
 import sys
-import tempfile
 from pathlib import Path
-from threading import Thread
-from typing import Any, Dict, List, Optional, Union
-from urllib.parse import urlparse
-
-import requests
-import torch
-import yaml
 
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
-from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
-from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.import_utils import _is_package_available
 
-from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
-from axolotl.train import TrainDatasetMeta
-from axolotl.utils.chat_templates import (
-    get_chat_template,
-    get_chat_template_from_config,
-)
-from axolotl.utils.comet_ import setup_comet_env_vars
-from axolotl.utils.config import (
-    normalize_cfg_datasets,
-    normalize_config,
-    prepare_plugins,
-    validate_config,
-)
-from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
-from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.mlflow_ import setup_mlflow_env_vars
-from axolotl.utils.models import load_processor, load_tokenizer
-from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
-from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
new file mode 100644
index 000000000..9ae830808
--- /dev/null
+++ b/src/axolotl/cli/config.py
@@ -0,0 +1,167 @@
+"""Configuration loading and processing."""
+import json
+import logging
+import os
+import tempfile
+from pathlib import Path
+from typing import Union
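+# NOTE: check_remote_config/choose_config/load_cfg below are relocated from
+# axolotl/cli/__init__.py; prepare_plugins moves here from axolotl.utils.config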
+from urllib.parse import urlparse + +import requests +import torch +import yaml +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.integrations.base import PluginManager +from axolotl.utils.comet_ import setup_comet_env_vars +from axolotl.utils.config import ( + normalize_cfg_datasets, + normalize_config, + validate_config, +) +from axolotl.utils.dict import DictDefault +from axolotl.utils.mlflow_ import setup_mlflow_env_vars +from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env +from axolotl.utils.wandb_ import setup_wandb_env_vars + +LOG = logging.getLogger("axolotl.cli.config") + + +def check_remote_config(config: Union[str, Path]): + # Check if the config is a valid HTTPS URL to a .yml or .yaml file + if not (isinstance(config, str) and config.startswith("https://")): + return config # Return the original value if it's not a valid URL + + filename = os.path.basename(urlparse(config).path) + temp_dir = tempfile.mkdtemp() + + try: + response = requests.get(config, timeout=30) + response.raise_for_status() # Check for HTTP errors + + content = response.content + try: + # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML + json.loads(content) + # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link + LOG.warning( + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." + ) + except json.JSONDecodeError: + # If it's not valid JSON, verify it's valid YAML + try: + yaml.safe_load(content) + except yaml.YAMLError as err: + raise ValueError( + f"Failed to parse the content at {config} as YAML: {err}" + ) from err + + # Write the content to a file if it's valid YAML (or JSON treated as YAML) + output_path = Path(temp_dir) / filename + with open(output_path, "wb") as file: + file.write(content) + LOG.info( + f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" + ) + return output_path + + except requests.RequestException as err: + # This catches all requests-related exceptions including HTTPError + raise RuntimeError(f"Failed to download {config}: {err}") from err + except Exception as err: + # Catch-all for any other exceptions + raise err + + +def choose_config(path: Path): + yaml_files = list(path.glob("*.yml")) + + if not yaml_files: + raise ValueError( + "No YAML config files found in the specified directory. Are you using a .yml extension?" + ) + + if len(yaml_files) == 1: + print(f"Using default YAML file '{yaml_files[0]}'") + return str(yaml_files[0]) + + print("Choose a YAML file:") + for idx, file in enumerate(yaml_files): + print(f"{idx + 1}. {file}") + + chosen_file = None + while chosen_file is None: + try: + choice = int(input("Enter the number of your choice: ")) + if 1 <= choice <= len(yaml_files): + chosen_file = str(yaml_files[choice - 1]) + else: + print("Invalid choice. Please choose a number from the list.") + except ValueError: + print("Invalid input. 
Please enter a number.") + + return chosen_file + + +def prepare_plugins(cfg): + """ + Prepare the plugins for the configuration + """ + + if cfg.get("plugins"): + plugin_manager = PluginManager.get_instance() + for plugin_name in cfg["plugins"]: + plugin_manager.register(plugin_name) + + +def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): + config = check_remote_config(config) + if Path(config).is_dir(): + config = choose_config(Path(config)) + + # load the config from the yaml file + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + # if there are any options passed in the cli, if it is something that seems valid from the yaml, + # then overwrite the value + cfg_keys = cfg.keys() + for k, _ in kwargs.items(): + # if not strict, allow writing to cfg even if it's not in the yml already + if k in cfg_keys or not cfg.strict: + # handle booleans + if isinstance(cfg[k], bool): + cfg[k] = bool(kwargs[k]) + else: + cfg[k] = kwargs[k] + + cfg.axolotl_config_path = config + + try: + device_props = torch.cuda.get_device_properties("cuda") + gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) + except: # pylint: disable=bare-except # noqa: E722 + gpu_version = None + + prepare_plugins(cfg) + + cfg = validate_config( + cfg, + capabilities={ + "bf16": is_torch_bf16_gpu_available(), + "n_gpu": int(os.environ.get("WORLD_SIZE", 1)), + "compute_capability": gpu_version, + }, + env_capabilities={ + "torch_version": str(torch.__version__).split("+", maxsplit=1)[0] + }, + ) + + prepare_optim_env(cfg) + prepare_opinionated_env(cfg) + normalize_config(cfg) + normalize_cfg_datasets(cfg) + setup_wandb_env_vars(cfg) + setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + + return cfg diff --git a/src/axolotl/cli/datasets.py b/src/axolotl/cli/datasets.py new file mode 100644 index 000000000..e7884eb90 --- /dev/null +++ b/src/axolotl/cli/datasets.py @@ -0,0 +1,92 @@ +"""Dataset loading utilities.""" +import logging +import math +import random + +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import TrainDatasetMeta +from axolotl.utils.data import prepare_dataset +from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_processor, load_tokenizer +from axolotl.utils.tokenization import check_dataset_labels + +LOG = logging.getLogger("axolotl.scripts") + + +def load_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +) -> TrainDatasetMeta: + tokenizer = load_tokenizer(cfg) + processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None + + train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( + cfg, + tokenizer, + processor=processor, + ) + + if ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 + ): + LOG.info("check_dataset_labels...") + check_dataset_labels( + train_dataset.select( + [ + random.randrange(0, len(train_dataset) - 1) # nosec + for _ in range(cli_args.debug_num_examples) + ] + ), + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + ) + + LOG.info("printing prompters...") + for prompter in prompters: + LOG.info(prompter) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + +def load_rl_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: 
disable=unused-argument +) -> TrainDatasetMeta: + train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + if cli_args.debug or cfg.debug: + LOG.info("check_dataset_labels...") + + tokenizer = load_tokenizer(cfg) + check_dataset_labels( + train_dataset.select( + [ + random.randrange(0, len(train_dataset) - 1) # nosec + for _ in range(cli_args.debug_num_examples) + ] + ), + tokenizer, + num_examples=cli_args.debug_num_examples, + text_only=cli_args.debug_text_only, + rl_mode=True, + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index a5f1a8ad8..8fb05b9d4 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,20 +1,220 @@ -""" -CLI to run inference on a trained model -""" +"""CLI to run inference on a trained model.""" +import importlib +import logging +import sys from pathlib import Path -from typing import Union +from threading import Thread +from typing import Optional, Union import fire +import torch import transformers from dotenv import load_dotenv +from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer -from axolotl.cli import ( - do_inference, - do_inference_gradio, - load_cfg, - print_axolotl_text_art, +from axolotl.cli import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.utils.chat_templates import ( + get_chat_template, + get_chat_template_from_config, ) -from axolotl.common.cli import TrainerCliArgs +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.cli.inference") + + +def get_multi_line_input() -> Optional[str]: + print("Give me an instruction (Ctrl + D to submit): ") + instruction = "" + for line in sys.stdin: + instruction += line # pylint: disable=consider-using-join + # instruction = pathlib.Path("/proc/self/fd/0").read_text() + return instruction + + +def do_inference( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +): + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template) + elif cfg.datasets[0].type == "chat_template": + chat_template_str = get_chat_template_from_config( + cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer + ) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + while True: + print("=" * 80) + # support for multiline inputs + instruction = get_multi_line_input() + if not instruction: + return + + if prompter_module: + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + print("=" * 40) + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + 
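+                # Fixed sampling defaults for interactive CLI inference; the
+                # gradio path reads max_new_tokens/temperature from cfg instead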
repetition_penalty=1.1, + max_new_tokens=1024, + temperature=0.9, + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextStreamer(tokenizer) + generated = model.generate( + inputs=batch["input_ids"].to(cfg.device), + generation_config=generation_config, + streamer=streamer, + ) + print("=" * 40) + print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) + + +def do_inference_gradio( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +): + import gradio as gr + + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + prompter = cli_args.prompter + + prompter_module = None + chat_template_str = None + if prompter: + prompter_module = getattr( + importlib.import_module("axolotl.prompters"), prompter + ) + elif cfg.chat_template: + chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer) + + model = model.to(cfg.device, dtype=cfg.torch_dtype) + + def generate(instruction): + if not instruction: + return + if prompter_module: + # pylint: disable=stop-iteration-return + prompt: str = next( + prompter_module().build_prompt(instruction=instruction.strip("\n")) + ) + else: + prompt = instruction.strip() + + if chat_template_str: + batch = tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": prompt, + } + ], + return_tensors="pt", + add_special_tokens=True, + add_generation_prompt=True, + chat_template=chat_template_str, + tokenize=True, + return_dict=True, + ) + else: + batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) + + model.eval() + with torch.no_grad(): + generation_config = GenerationConfig( + repetition_penalty=1.1, + max_new_tokens=cfg.get("gradio_max_new_tokens", 1024), + temperature=cfg.get("gradio_temperature", 0.9), + top_p=0.95, + top_k=40, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=True, + use_cache=True, + return_dict_in_generate=True, + output_attentions=False, + output_hidden_states=False, + output_scores=False, + ) + streamer = TextIteratorStreamer(tokenizer) + generation_kwargs = { + "inputs": batch["input_ids"].to(cfg.device), + "attention_mask": batch["attention_mask"].to(cfg.device), + "generation_config": generation_config, + "streamer": streamer, + } + + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + all_text = "" + + for new_text in streamer: + all_text += new_text + yield all_text + + demo = gr.Interface( + fn=generate, + inputs="textbox", + outputs="text", + title=cfg.get("gradio_title", "Axolotl Gradio Interface"), + ) + + demo.queue().launch( + show_api=False, + share=cfg.get("gradio_share", True), + server_name=cfg.get("gradio_server_name", "127.0.0.1"), + server_port=cfg.get("gradio_server_port", None), + ) def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 8c321bc48..bf30d9c33 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -1,6 +1,7 @@ """ CLI to run merge a trained LoRA into a base model """ +import logging from pathlib import Path from typing import Union @@ -8,8 +9,38 @@ import fire import transformers from dotenv import load_dotenv -from axolotl.cli 
import do_merge_lora, load_cfg, print_axolotl_text_art -from axolotl.common.cli import TrainerCliArgs +from axolotl.cli import print_axolotl_text_art +from axolotl.cli.config import load_cfg +from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.cli.merge_lora") + + +def do_merge_lora( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, +): + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + safe_serialization = cfg.save_safetensors is True + + LOG.info("running merge of LoRA with base model") + model = model.merge_and_unload(progressbar=True) + try: + model.to(dtype=cfg.torch_dtype) + except RuntimeError: + pass + model.generation_config.do_sample = True + + if cfg.local_rank == 0: + LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") + model.save_pretrained( + str(Path(cfg.output_dir) / "merged"), + safe_serialization=safe_serialization, + progressbar=True, + ) + tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 6be9af1f7..152fa08f0 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -25,7 +25,8 @@ from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner -from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.cli import print_axolotl_text_art +from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights") diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index a1592aa78..efb2a1809 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -16,11 +16,10 @@ from transformers import AutoModelForCausalLM from axolotl.cli import ( check_accelerate_default_config, check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, print_axolotl_text_art, ) +from axolotl.cli.config import load_cfg +from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.utils.trainer import disable_datasets_caching diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index 196c0e99a..658442992 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -9,7 +9,8 @@ import fire import transformers from dotenv import load_dotenv -from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.cli import print_axolotl_text_art +from axolotl.cli.config import load_cfg from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.utils.dict import DictDefault diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2a40e854e..41f4b2d83 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,11 +12,10 @@ from transformers.hf_argparser import HfArgumentParser from axolotl.cli import ( check_accelerate_default_config, check_user_token, - load_cfg, - load_datasets, - load_rl_datasets, print_axolotl_text_art, ) +from axolotl.cli.config import load_cfg +from axolotl.cli.datasets import load_datasets, load_rl_datasets 
from axolotl.common.cli import TrainerCliArgs from axolotl.integrations.base import PluginManager from axolotl.train import train
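
For reviewers, a minimal usage sketch of the refactored entry points, mirroring the imports this patch adds to preprocess.py and train.py. It is illustrative only: the "config.yml" path is a placeholder, and TrainerCliArgs() is assumed to construct with its dataclass defaults:

    from axolotl.cli.config import load_cfg
    from axolotl.cli.datasets import load_datasets
    from axolotl.common.cli import TrainerCliArgs

    # Resolve the config (local path, directory, or https:// URL),
    # then validate and normalize it
    cfg = load_cfg("config.yml")

    # Tokenize and prepare the train/eval splits; returns a
    # TrainDatasetMeta carrying the datasets and total_num_steps
    dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())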