CLI init refactor

Dan Saunders
2024-12-06 11:57:53 -05:00
committed by Dan Saunders
parent f89e962119
commit 541f9b39ff
9 changed files with 513 additions and 55 deletions

View File

@@ -1,53 +1,18 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -105,6 +70,7 @@ def print_dep_versions():
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
    # Check if the config is a valid HTTPS URL to a .yml or .yaml file
    if not (isinstance(config, str) and config.startswith("https://")):
@@ -541,6 +507,8 @@ def load_rl_datasets(
    )
def check_accelerate_default_config():
    if Path(config_args.default_yaml_config_file).exists():
        LOG.warning(

src/axolotl/cli/config.py Normal file
View File

@@ -0,0 +1,167 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Union
from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger("axolotl.cli.config")
def check_remote_config(config: Union[str, Path]):
    # Check if the config is a valid HTTPS URL to a .yml or .yaml file
    if not (isinstance(config, str) and config.startswith("https://")):
        return config  # Return the original value if it's not a valid URL

    filename = os.path.basename(urlparse(config).path)
    temp_dir = tempfile.mkdtemp()

    try:
        response = requests.get(config, timeout=30)
        response.raise_for_status()  # Check for HTTP errors

        content = response.content
        try:
            # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
            json.loads(content)
            # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
            LOG.warning(
                f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
            )
        except json.JSONDecodeError:
            # If it's not valid JSON, verify it's valid YAML
            try:
                yaml.safe_load(content)
            except yaml.YAMLError as err:
                raise ValueError(
                    f"Failed to parse the content at {config} as YAML: {err}"
                ) from err

        # Write the content to a file if it's valid YAML (or JSON treated as YAML)
        output_path = Path(temp_dir) / filename
        with open(output_path, "wb") as file:
            file.write(content)
        LOG.info(
            f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
        )

        return output_path
    except requests.RequestException as err:
        # This catches all requests-related exceptions including HTTPError
        raise RuntimeError(f"Failed to download {config}: {err}") from err
    except Exception as err:
        # Catch-all for any other exceptions
        raise err
def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    if len(yaml_files) == 1:
        print(f"Using default YAML file '{yaml_files[0]}'")
        return str(yaml_files[0])

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
        try:
            choice = int(input("Enter the number of your choice: "))
            if 1 <= choice <= len(yaml_files):
                chosen_file = str(yaml_files[choice - 1])
            else:
                print("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")

    return chosen_file
def prepare_plugins(cfg):
    """
    Prepare the plugins for the configuration
    """
    if cfg.get("plugins"):
        plugin_manager = PluginManager.get_instance()
        for plugin_name in cfg["plugins"]:
            plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
    config = check_remote_config(config)
    if Path(config).is_dir():
        config = choose_config(Path(config))

    # load the config from the yaml file
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))

    # if there are any options passed in the cli, and they look valid against the yaml,
    # then overwrite the values
    cfg_keys = cfg.keys()
    for k, _ in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or not cfg.strict:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]

    cfg.axolotl_config_path = config

    try:
        device_props = torch.cuda.get_device_properties("cuda")
        gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
    except:  # pylint: disable=bare-except  # noqa: E722
        gpu_version = None

    prepare_plugins(cfg)

    cfg = validate_config(
        cfg,
        capabilities={
            "bf16": is_torch_bf16_gpu_available(),
            "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
            "compute_capability": gpu_version,
        },
        env_capabilities={
            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
        },
    )

    prepare_optim_env(cfg)
    prepare_opinionated_env(cfg)
    normalize_config(cfg)
    normalize_cfg_datasets(cfg)
    setup_wandb_env_vars(cfg)
    setup_mlflow_env_vars(cfg)
    setup_comet_env_vars(cfg)

    return cfg
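
Taken together, `load_cfg` resolves a local file, a directory of YAML configs, or an `https://` URL, applies CLI keyword overrides, and runs validation and environment setup. A minimal usage sketch, assuming a hypothetical config path and override values:

```python
from axolotl.cli.config import load_cfg

# hypothetical path; a directory or an https:// raw-file URL also works here
cfg = load_cfg(
    "examples/openllama-3b/lora.yml",
    micro_batch_size=2,    # overwrites the YAML value if the key exists (or cfg.strict is off)
    flash_attention=True,  # keys that are booleans in the YAML are coerced via bool(...)
)
print(cfg.axolotl_config_path)
```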

View File

@@ -0,0 +1,92 @@
"""Dataset loading utilities."""
import logging
import math
import random
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import TrainDatasetMeta
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger("axolotl.scripts")
def load_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
    tokenizer = load_tokenizer(cfg)
    processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None

    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
        cfg,
        tokenizer,
        processor=processor,
    )

    if (
        cli_args.debug
        or cfg.debug
        or cli_args.debug_text_only
        or int(cli_args.debug_num_examples) > 0
    ):
        LOG.info("check_dataset_labels...")
        check_dataset_labels(
            train_dataset.select(
                [
                    random.randrange(0, len(train_dataset) - 1)  # nosec
                    for _ in range(cli_args.debug_num_examples)
                ]
            ),
            tokenizer,
            num_examples=cli_args.debug_num_examples,
            text_only=cli_args.debug_text_only,
        )

        LOG.info("printing prompters...")
        for prompter in prompters:
            LOG.info(prompter)

    return TrainDatasetMeta(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        total_num_steps=total_num_steps,
    )
def load_rl_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
) -> TrainDatasetMeta:
    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
    total_num_steps = int(
        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
    )

    if cli_args.debug or cfg.debug:
        LOG.info("check_dataset_labels...")

        tokenizer = load_tokenizer(cfg)
        check_dataset_labels(
            train_dataset.select(
                [
                    random.randrange(0, len(train_dataset) - 1)  # nosec
                    for _ in range(cli_args.debug_num_examples)
                ]
            ),
            tokenizer,
            num_examples=cli_args.debug_num_examples,
            text_only=cli_args.debug_text_only,
            rl_mode=True,
        )

    return TrainDatasetMeta(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        total_num_steps=total_num_steps,
    )
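
Both loaders return a `TrainDatasetMeta` carrying the train/eval splits and the computed step count. A sketch of how they might be driven from a prepared config, assuming default `TrainerCliArgs` field values:

```python
from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs

cfg = load_cfg("examples/openllama-3b/lora.yml")  # hypothetical path
dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
print(dataset_meta.total_num_steps)
```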

View File

@@ -1,20 +1,220 @@
"""
CLI to run inference on a trained model
"""
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from typing import Union
from threading import Thread
from typing import Optional, Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli import (
    do_inference,
    do_inference_gradio,
    load_cfg,
    print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.chat_templates import (
    get_chat_template,
    get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.inference")
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_inference(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    prompter_module = None
    chat_template_str = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )
    elif cfg.chat_template:
        chat_template_str = get_chat_template(cfg.chat_template)
    elif cfg.datasets[0].type == "chat_template":
        chat_template_str = get_chat_template_from_config(
            cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
        )

    model = model.to(cfg.device, dtype=cfg.torch_dtype)

    while True:
        print("=" * 80)
        # support for multiline inputs
        instruction = get_multi_line_input()
        if not instruction:
            return
        if prompter_module:
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()

        if chat_template_str:
            batch = tokenizer.apply_chat_template(
                [
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                return_tensors="pt",
                add_special_tokens=True,
                add_generation_prompt=True,
                chat_template=chat_template_str,
                tokenize=True,
                return_dict=True,
            )
        else:
            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        print("=" * 40)
        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextStreamer(tokenizer)
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                generation_config=generation_config,
                streamer=streamer,
            )
        print("=" * 40)
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    import gradio as gr

    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    prompter_module = None
    chat_template_str = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )
    elif cfg.chat_template:
        chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)

    model = model.to(cfg.device, dtype=cfg.torch_dtype)

    def generate(instruction):
        if not instruction:
            return
        if prompter_module:
            # pylint: disable=stop-iteration-return
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()

        if chat_template_str:
            batch = tokenizer.apply_chat_template(
                [
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                return_tensors="pt",
                add_special_tokens=True,
                add_generation_prompt=True,
                chat_template=chat_template_str,
                tokenize=True,
                return_dict=True,
            )
        else:
            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
                temperature=cfg.get("gradio_temperature", 0.9),
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextIteratorStreamer(tokenizer)

            generation_kwargs = {
                "inputs": batch["input_ids"].to(cfg.device),
                "attention_mask": batch["attention_mask"].to(cfg.device),
                "generation_config": generation_config,
                "streamer": streamer,
            }

            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()
            all_text = ""

            for new_text in streamer:
                all_text += new_text
                yield all_text

    demo = gr.Interface(
        fn=generate,
        inputs="textbox",
        outputs="text",
        title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
    )

    demo.queue().launch(
        show_api=False,
        share=cfg.get("gradio_share", True),
        server_name=cfg.get("gradio_server_name", "127.0.0.1"),
        server_port=cfg.get("gradio_server_port", None),
    )
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
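
For reference, the sampling defaults shared by both inference paths above can be reproduced outside axolotl with plain transformers; a minimal sketch using a placeholder model name:

```python
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextStreamer,
)

model_name = "openlm-research/open_llama_3b"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer("Give me a haiku about axolotls.", return_tensors="pt")
generation_config = GenerationConfig(
    repetition_penalty=1.1,
    max_new_tokens=64,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    do_sample=True,
)
model.eval()
with torch.no_grad():
    # stream tokens to stdout as they are generated, as do_inference does
    model.generate(
        inputs=batch["input_ids"],
        generation_config=generation_config,
        streamer=TextStreamer(tokenizer),
    )
```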

View File

@@ -1,6 +1,7 @@
"""
CLI to run merge a trained LoRA into a base model
"""
import logging
from pathlib import Path
from typing import Union
@@ -8,8 +9,38 @@ import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.merge_lora")
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
except RuntimeError:
pass
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
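
A hedged sketch of invoking the merge programmatically rather than through the CLI entry point (the config path and `lora_model_dir` are placeholders, and `TrainerCliArgs` is assumed to expose a `merge_lora` flag as elsewhere in axolotl):

```python
from axolotl.cli.config import load_cfg
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.common.cli import TrainerCliArgs

# placeholder paths for illustration
cfg = load_cfg("examples/openllama-3b/lora.yml", lora_model_dir="./lora-out")
do_merge_lora(cfg=cfg, cli_args=TrainerCliArgs(merge_lora=True))
# merged weights land in <output_dir>/merged
```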

View File

@@ -25,7 +25,8 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.cli import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")

View File

@@ -16,11 +16,10 @@ from transformers import AutoModelForCausalLM
from axolotl.cli import (
    check_accelerate_default_config,
    check_user_token,
    load_cfg,
    load_datasets,
    load_rl_datasets,
    print_axolotl_text_art,
)
from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.trainer import disable_datasets_caching

View File

@@ -9,7 +9,8 @@ import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.cli import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault

View File

@@ -12,11 +12,10 @@ from transformers.hf_argparser import HfArgumentParser
from axolotl.cli import (
    check_accelerate_default_config,
    check_user_token,
    load_cfg,
    load_datasets,
    load_rl_datasets,
    print_axolotl_text_art,
)
from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.train import train
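
With the refactor, an end-to-end run stitches the new modules together; a minimal sketch, assuming a hypothetical config path and the keyword-only `train` signature used in this codebase:

```python
from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train

cfg = load_cfg("examples/openllama-3b/lora.yml")  # hypothetical path
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
```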