cleanup and (partial) docs
This commit is contained in:
@@ -5,15 +5,12 @@ from pathlib import Path
|
|||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
check_accelerate_default_config,
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
check_user_token,
|
from axolotl.cli.config import load_cfg
|
||||||
do_inference,
|
from axolotl.cli.datasets import load_datasets
|
||||||
do_merge_lora,
|
from axolotl.cli.inference import do_inference
|
||||||
load_cfg,
|
from axolotl.cli.merge_lora import do_merge_lora
|
||||||
load_datasets,
|
|
||||||
print_axolotl_text_art,
|
|
||||||
)
|
|
||||||
from axolotl.cli.shard import shard
|
from axolotl.cli.shard import shard
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|||||||
@@ -1,536 +1,5 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Axolotl CLI module initialization."""
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
|
||||||
from accelerate.commands.config import config_args
|
|
||||||
from art import text2art
|
|
||||||
from huggingface_hub import HfApi
|
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
|
||||||
from transformers.utils.import_utils import _is_package_available
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
src_dir = os.path.join(project_root, "src")
|
|
||||||
sys.path.insert(0, src_dir)
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = logging.getLogger("axolotl.scripts")
|
|
||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||||
|
|
||||||
AXOLOTL_LOGO = """
|
|
||||||
#@@ #@@ @@# @@#
|
|
||||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
|
||||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
|
||||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
|
||||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
|
||||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
|
||||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
|
||||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
|
||||||
@@@@ @@@@@@@@@@@@@@@@
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def print_legacy_axolotl_text_art(suffix=None):
|
|
||||||
font = "nancyj"
|
|
||||||
ascii_text = " axolotl"
|
|
||||||
if suffix:
|
|
||||||
ascii_text += f" x {suffix}"
|
|
||||||
ascii_art = text2art(ascii_text, font=font)
|
|
||||||
|
|
||||||
if is_main_process():
|
|
||||||
print(ascii_art)
|
|
||||||
|
|
||||||
print_dep_versions()
|
|
||||||
|
|
||||||
|
|
||||||
def print_axolotl_text_art(
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if is_main_process():
|
|
||||||
print(AXOLOTL_LOGO)
|
|
||||||
|
|
||||||
|
|
||||||
def print_dep_versions():
|
|
||||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
|
||||||
max_len = max(len(pkg) for pkg in packages)
|
|
||||||
if is_main_process():
|
|
||||||
print("*" * 40)
|
|
||||||
print("**** Axolotl Dependency Versions *****")
|
|
||||||
for pkg in packages:
|
|
||||||
pkg_version = _is_package_available(pkg, return_version=True)
|
|
||||||
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
|
||||||
print("*" * 40)
|
|
||||||
|
|
||||||
|
|
||||||
<<<<<<< HEAD
|
|
||||||
def check_remote_config(config: Union[str, Path]):
|
|
||||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
|
||||||
if not (isinstance(config, str) and config.startswith("https://")):
|
|
||||||
return config # Return the original value if it's not a valid URL
|
|
||||||
|
|
||||||
filename = os.path.basename(urlparse(config).path)
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(config, timeout=30)
|
|
||||||
response.raise_for_status() # Check for HTTP errors
|
|
||||||
|
|
||||||
content = response.content
|
|
||||||
try:
|
|
||||||
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
|
||||||
json.loads(content)
|
|
||||||
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
|
||||||
LOG.warning(
|
|
||||||
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# If it's not valid JSON, verify it's valid YAML
|
|
||||||
try:
|
|
||||||
yaml.safe_load(content)
|
|
||||||
except yaml.YAMLError as err:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to parse the content at {config} as YAML: {err}"
|
|
||||||
) from err
|
|
||||||
|
|
||||||
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
|
|
||||||
output_path = Path(temp_dir) / filename
|
|
||||||
with open(output_path, "wb") as file:
|
|
||||||
file.write(content)
|
|
||||||
LOG.info(
|
|
||||||
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
|
|
||||||
)
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
except requests.RequestException as err:
|
|
||||||
# This catches all requests-related exceptions including HTTPError
|
|
||||||
raise RuntimeError(f"Failed to download {config}: {err}") from err
|
|
||||||
except Exception as err:
|
|
||||||
# Catch-all for any other exceptions
|
|
||||||
raise err
|
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
|
||||||
print("Give me an instruction (Ctrl + D to submit): ")
|
|
||||||
instruction = ""
|
|
||||||
for line in sys.stdin:
|
|
||||||
instruction += line # pylint: disable=consider-using-join
|
|
||||||
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
|
||||||
return instruction
|
|
||||||
|
|
||||||
|
|
||||||
def do_merge_lora(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
|
||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
|
||||||
model = model.merge_and_unload(progressbar=True)
|
|
||||||
try:
|
|
||||||
model.to(dtype=cfg.torch_dtype)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
model.generation_config.do_sample = True
|
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
|
||||||
model.save_pretrained(
|
|
||||||
str(Path(cfg.output_dir) / "merged"),
|
|
||||||
safe_serialization=safe_serialization,
|
|
||||||
progressbar=True,
|
|
||||||
)
|
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
|
||||||
|
|
||||||
|
|
||||||
def do_inference(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
prompter = cli_args.prompter
|
|
||||||
|
|
||||||
prompter_module = None
|
|
||||||
chat_template_str = None
|
|
||||||
if prompter:
|
|
||||||
prompter_module = getattr(
|
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
|
||||||
)
|
|
||||||
elif cfg.chat_template:
|
|
||||||
chat_template_str = get_chat_template(cfg.chat_template)
|
|
||||||
elif cfg.datasets[0].type == "chat_template":
|
|
||||||
chat_template_str = get_chat_template_from_config(
|
|
||||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
print("=" * 80)
|
|
||||||
# support for multiline inputs
|
|
||||||
instruction = get_multi_line_input()
|
|
||||||
if not instruction:
|
|
||||||
return
|
|
||||||
|
|
||||||
if prompter_module:
|
|
||||||
prompt: str = next(
|
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = instruction.strip()
|
|
||||||
|
|
||||||
if chat_template_str:
|
|
||||||
batch = tokenizer.apply_chat_template(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
return_tensors="pt",
|
|
||||||
add_special_tokens=True,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
chat_template=chat_template_str,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
|
||||||
print("=" * 40)
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
generation_config = GenerationConfig(
|
|
||||||
repetition_penalty=1.1,
|
|
||||||
max_new_tokens=1024,
|
|
||||||
temperature=0.9,
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=40,
|
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
output_scores=False,
|
|
||||||
)
|
|
||||||
streamer = TextStreamer(tokenizer)
|
|
||||||
generated = model.generate(
|
|
||||||
inputs=batch["input_ids"].to(cfg.device),
|
|
||||||
generation_config=generation_config,
|
|
||||||
streamer=streamer,
|
|
||||||
)
|
|
||||||
print("=" * 40)
|
|
||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
|
||||||
|
|
||||||
|
|
||||||
def do_inference_gradio(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
):
|
|
||||||
import gradio as gr
|
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
|
||||||
prompter = cli_args.prompter
|
|
||||||
|
|
||||||
prompter_module = None
|
|
||||||
chat_template_str = None
|
|
||||||
if prompter:
|
|
||||||
prompter_module = getattr(
|
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
|
||||||
)
|
|
||||||
elif cfg.chat_template:
|
|
||||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
|
||||||
|
|
||||||
def generate(instruction):
|
|
||||||
if not instruction:
|
|
||||||
return
|
|
||||||
if prompter_module:
|
|
||||||
# pylint: disable=stop-iteration-return
|
|
||||||
prompt: str = next(
|
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = instruction.strip()
|
|
||||||
|
|
||||||
if chat_template_str:
|
|
||||||
batch = tokenizer.apply_chat_template(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
return_tensors="pt",
|
|
||||||
add_special_tokens=True,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
chat_template=chat_template_str,
|
|
||||||
tokenize=True,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
generation_config = GenerationConfig(
|
|
||||||
repetition_penalty=1.1,
|
|
||||||
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
|
||||||
temperature=cfg.get("gradio_temperature", 0.9),
|
|
||||||
top_p=0.95,
|
|
||||||
top_k=40,
|
|
||||||
bos_token_id=tokenizer.bos_token_id,
|
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
|
||||||
do_sample=True,
|
|
||||||
use_cache=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
output_scores=False,
|
|
||||||
)
|
|
||||||
streamer = TextIteratorStreamer(tokenizer)
|
|
||||||
generation_kwargs = {
|
|
||||||
"inputs": batch["input_ids"].to(cfg.device),
|
|
||||||
"attention_mask": batch["attention_mask"].to(cfg.device),
|
|
||||||
"generation_config": generation_config,
|
|
||||||
"streamer": streamer,
|
|
||||||
}
|
|
||||||
|
|
||||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
all_text = ""
|
|
||||||
|
|
||||||
for new_text in streamer:
|
|
||||||
all_text += new_text
|
|
||||||
yield all_text
|
|
||||||
|
|
||||||
demo = gr.Interface(
|
|
||||||
fn=generate,
|
|
||||||
inputs="textbox",
|
|
||||||
outputs="text",
|
|
||||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
|
||||||
)
|
|
||||||
|
|
||||||
demo.queue().launch(
|
|
||||||
show_api=False,
|
|
||||||
share=cfg.get("gradio_share", True),
|
|
||||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
|
||||||
server_port=cfg.get("gradio_server_port", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
|
||||||
yaml_files = list(path.glob("*.yml"))
|
|
||||||
|
|
||||||
if not yaml_files:
|
|
||||||
raise ValueError(
|
|
||||||
"No YAML config files found in the specified directory. Are you using a .yml extension?"
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(yaml_files) == 1:
|
|
||||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
|
||||||
return str(yaml_files[0])
|
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
|
||||||
for idx, file in enumerate(yaml_files):
|
|
||||||
print(f"{idx + 1}. {file}")
|
|
||||||
|
|
||||||
chosen_file = None
|
|
||||||
while chosen_file is None:
|
|
||||||
try:
|
|
||||||
choice = int(input("Enter the number of your choice: "))
|
|
||||||
if 1 <= choice <= len(yaml_files):
|
|
||||||
chosen_file = str(yaml_files[choice - 1])
|
|
||||||
else:
|
|
||||||
print("Invalid choice. Please choose a number from the list.")
|
|
||||||
except ValueError:
|
|
||||||
print("Invalid input. Please enter a number.")
|
|
||||||
|
|
||||||
return chosen_file
|
|
||||||
|
|
||||||
|
|
||||||
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
|
||||||
return not any(el in list2 for el in list1)
|
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|
||||||
config = check_remote_config(config)
|
|
||||||
if Path(config).is_dir():
|
|
||||||
config = choose_config(Path(config))
|
|
||||||
|
|
||||||
# load the config from the yaml file
|
|
||||||
with open(config, encoding="utf-8") as file:
|
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
|
||||||
# then overwrite the value
|
|
||||||
cfg_keys = cfg.keys()
|
|
||||||
for k, _ in kwargs.items():
|
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
|
||||||
if k in cfg_keys or not cfg.strict:
|
|
||||||
# handle booleans
|
|
||||||
if isinstance(cfg[k], bool):
|
|
||||||
cfg[k] = bool(kwargs[k])
|
|
||||||
else:
|
|
||||||
cfg[k] = kwargs[k]
|
|
||||||
|
|
||||||
cfg.axolotl_config_path = config
|
|
||||||
|
|
||||||
try:
|
|
||||||
device_props = torch.cuda.get_device_properties("cuda")
|
|
||||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
|
||||||
except: # pylint: disable=bare-except # noqa: E722
|
|
||||||
gpu_version = None
|
|
||||||
|
|
||||||
prepare_plugins(cfg)
|
|
||||||
|
|
||||||
cfg = validate_config(
|
|
||||||
cfg,
|
|
||||||
capabilities={
|
|
||||||
"bf16": is_torch_bf16_gpu_available(),
|
|
||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
|
||||||
"compute_capability": gpu_version,
|
|
||||||
},
|
|
||||||
env_capabilities={
|
|
||||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
|
||||||
|
|
||||||
prepare_opinionated_env(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
|
||||||
|
|
||||||
normalize_cfg_datasets(cfg)
|
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
|
||||||
|
|
||||||
setup_mlflow_env_vars(cfg)
|
|
||||||
|
|
||||||
setup_comet_env_vars(cfg)
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs,
|
|
||||||
) -> TrainDatasetMeta:
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
|
||||||
cfg,
|
|
||||||
tokenizer,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
cli_args.debug
|
|
||||||
or cfg.debug
|
|
||||||
or cli_args.debug_text_only
|
|
||||||
or int(cli_args.debug_num_examples) > 0
|
|
||||||
):
|
|
||||||
LOG.info("check_dataset_labels...")
|
|
||||||
check_dataset_labels(
|
|
||||||
train_dataset.select(
|
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
num_examples=cli_args.debug_num_examples,
|
|
||||||
text_only=cli_args.debug_text_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
|
||||||
for prompter in prompters:
|
|
||||||
LOG.info(prompter)
|
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
total_num_steps=total_num_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_rl_datasets(
|
|
||||||
*,
|
|
||||||
cfg: DictDefault,
|
|
||||||
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
|
|
||||||
) -> TrainDatasetMeta:
|
|
||||||
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
|
||||||
total_num_steps = int(
|
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
|
||||||
LOG.info("check_dataset_labels...")
|
|
||||||
|
|
||||||
tokenizer = load_tokenizer(cfg)
|
|
||||||
check_dataset_labels(
|
|
||||||
train_dataset.select(
|
|
||||||
[
|
|
||||||
random.randrange(0, len(train_dataset) - 1) # nosec
|
|
||||||
for _ in range(cli_args.debug_num_examples)
|
|
||||||
]
|
|
||||||
),
|
|
||||||
tokenizer,
|
|
||||||
num_examples=cli_args.debug_num_examples,
|
|
||||||
text_only=cli_args.debug_text_only,
|
|
||||||
rl_mode=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
total_num_steps=total_num_steps,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
=======
|
|
||||||
>>>>>>> 73d65961 (CLI init refactor)
|
|
||||||
def check_accelerate_default_config():
|
|
||||||
if Path(config_args.default_yaml_config_file).exists():
|
|
||||||
LOG.warning(
|
|
||||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_user_token():
|
|
||||||
# Skip check if HF_HUB_OFFLINE is set to True
|
|
||||||
if os.getenv("HF_HUB_OFFLINE") == "1":
|
|
||||||
LOG.info(
|
|
||||||
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Verify if token is valid
|
|
||||||
api = HfApi()
|
|
||||||
try:
|
|
||||||
user_info = api.whoami()
|
|
||||||
return bool(user_info)
|
|
||||||
except LocalTokenNotFoundError:
|
|
||||||
LOG.warning(
|
|
||||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|||||||
50
src/axolotl/cli/art.py
Normal file
50
src/axolotl/cli/art.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""Axolotl ASCII logo utils."""
|
||||||
|
|
||||||
|
from art import text2art
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
AXOLOTL_LOGO = """
|
||||||
|
#@@ #@@ @@# @@#
|
||||||
|
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||||
|
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||||
|
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||||
|
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||||
|
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||||
|
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||||
|
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||||
|
@@@@ @@@@@@@@@@@@@@@@
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def print_dep_versions():
|
||||||
|
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||||
|
max_len = max(len(pkg) for pkg in packages)
|
||||||
|
if is_main_process():
|
||||||
|
print("*" * 40)
|
||||||
|
print("**** Axolotl Dependency Versions *****")
|
||||||
|
for pkg in packages:
|
||||||
|
pkg_version = _is_package_available(pkg, return_version=True)
|
||||||
|
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||||
|
print("*" * 40)
|
||||||
|
|
||||||
|
|
||||||
|
def print_legacy_axolotl_text_art(suffix=None):
|
||||||
|
font = "nancyj"
|
||||||
|
ascii_text = " axolotl"
|
||||||
|
if suffix:
|
||||||
|
ascii_text += f" x {suffix}"
|
||||||
|
ascii_art = text2art(ascii_text, font=font)
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
print(ascii_art)
|
||||||
|
|
||||||
|
print_dep_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def print_axolotl_text_art():
|
||||||
|
if is_main_process():
|
||||||
|
print(AXOLOTL_LOGO)
|
||||||
41
src/axolotl/cli/checks.py
Normal file
41
src/axolotl/cli/checks.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Various checks for Axolotl CLI."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from accelerate.commands.config import config_args
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_accelerate_default_config():
|
||||||
|
if Path(config_args.default_yaml_config_file).exists():
|
||||||
|
LOG.warning(
|
||||||
|
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_token():
|
||||||
|
# Skip check if HF_HUB_OFFLINE is set to True
|
||||||
|
if os.getenv("HF_HUB_OFFLINE") == "1":
|
||||||
|
LOG.info(
|
||||||
|
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Verify if token is valid
|
||||||
|
api = HfApi()
|
||||||
|
try:
|
||||||
|
user_info = api.whoami()
|
||||||
|
return bool(user_info)
|
||||||
|
except LocalTokenNotFoundError:
|
||||||
|
LOG.warning(
|
||||||
|
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
|
)
|
||||||
|
return False
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Configuration loading and processing."""
|
"""Configuration loading and processing."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -24,7 +25,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
|||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.config")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]):
|
def check_remote_config(config: Union[str, Path]):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.models import load_processor, load_tokenizer
|
from axolotl.utils.models import load_processor, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.scripts")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to run evaluation on a model."""
|
||||||
CLI to run training on a model
|
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -9,18 +8,14 @@ import fire
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
check_accelerate_default_config,
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
check_user_token,
|
from axolotl.cli.config import load_cfg
|
||||||
load_cfg,
|
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
||||||
load_datasets,
|
|
||||||
load_rl_datasets,
|
|
||||||
print_axolotl_text_art,
|
|
||||||
)
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.evaluate import evaluate
|
from axolotl.evaluate import evaluate
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.evaluate")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_evaluate(cfg, cli_args) -> None:
|
def do_evaluate(cfg, cli_args) -> None:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""CLI to run inference on a trained model."""
|
"""CLI to run inference on a trained model."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
@@ -12,7 +13,7 @@ import transformers
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -21,7 +22,7 @@ from axolotl.utils.chat_templates import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.inference")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""CLI definition for various axolotl commands."""
|
"""Click CLI definitions for various axolotl commands."""
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import subprocess # nosec B404
|
import subprocess # nosec B404
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to merge a trained LoRA into a base model."""
|
||||||
CLI to run merge a trained LoRA into a base model
|
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -9,7 +8,7 @@ import fire
|
|||||||
import transformers
|
import transformers
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
|
||||||
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
|
|
||||||
"""
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -25,11 +24,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
from axolotl.cli import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to run preprocessing of a dataset."""
|
||||||
CLI to run training on a model
|
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -13,18 +12,15 @@ from colorama import Fore
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
check_accelerate_default_config,
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
check_user_token,
|
|
||||||
print_axolotl_text_art,
|
|
||||||
)
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.utils.trainer import disable_datasets_caching
|
from axolotl.utils.trainer import disable_datasets_caching
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to shard a trained model into 10GiB chunks."""
|
||||||
CLI to shard a trained model into 10GiB chunks
|
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -9,12 +8,12 @@ import fire
|
|||||||
import transformers
|
import transformers
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli import print_axolotl_text_art
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.scripts")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def shard(
|
def shard(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""CLI to run training on a model."""
|
||||||
CLI to run training on a model
|
|
||||||
"""
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -9,18 +8,15 @@ import fire
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
check_accelerate_default_config,
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
check_user_token,
|
|
||||||
print_axolotl_text_art,
|
|
||||||
)
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Utility methods for axoltl CLI."""
|
"""Utility methods for axolotl CLI."""
|
||||||
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -12,11 +13,16 @@ import click
|
|||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.utils")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def add_options_from_dataclass(config_class: Type[Any]):
|
def add_options_from_dataclass(config_class: Type[Any]):
|
||||||
"""Create Click options from the fields of a dataclass."""
|
"""
|
||||||
|
Create Click options from the fields of a dataclass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_class: Dataclass with fields to parse from the CLI
|
||||||
|
"""
|
||||||
|
|
||||||
def decorator(function):
|
def decorator(function):
|
||||||
# Process dataclass fields in reverse order for correct option ordering
|
# Process dataclass fields in reverse order for correct option ordering
|
||||||
@@ -49,7 +55,12 @@ def add_options_from_dataclass(config_class: Type[Any]):
|
|||||||
|
|
||||||
|
|
||||||
def add_options_from_config(config_class: Type[BaseModel]):
|
def add_options_from_config(config_class: Type[BaseModel]):
|
||||||
"""Create Click options from the fields of a Pydantic model."""
|
"""
|
||||||
|
Create Click options from the fields of a Pydantic model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_class: Pydantic model with fields to parse from the CLI
|
||||||
|
"""
|
||||||
|
|
||||||
def decorator(function):
|
def decorator(function):
|
||||||
# Process model fields in reverse order for correct option ordering
|
# Process model fields in reverse order for correct option ordering
|
||||||
@@ -71,7 +82,16 @@ def add_options_from_config(config_class: Type[BaseModel]):
|
|||||||
|
|
||||||
|
|
||||||
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
|
||||||
"""Build command list from base command and options."""
|
"""
|
||||||
|
Build command list from base command and options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_cmd: Command without options
|
||||||
|
options: Options to parse and append to base command
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of strings giving shell command
|
||||||
|
"""
|
||||||
cmd = base_cmd.copy()
|
cmd = base_cmd.copy()
|
||||||
|
|
||||||
for key, value in options.items():
|
for key, value in options.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user