Adding documentation and continuing cleanup (in progress)

Dan Saunders
2025-01-07 20:16:39 +00:00
parent 324c533adb
commit 773c3b51cd
14 changed files with 283 additions and 192 deletions

View File

@@ -1,4 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora."""
import logging
from pathlib import Path
@@ -30,19 +31,19 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
do_merge_lora(cfg=parsed_cfg)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
shard(cfg=parsed_cfg)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
train(cfg=parsed_cfg, dataset_meta=dataset_meta)
if __name__ == "__main__":

View File

@@ -21,6 +21,7 @@ AXOLOTL_LOGO = """
def print_dep_versions():
"""Prints versions of various axolotl dependencies."""
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
@@ -33,6 +34,7 @@ def print_dep_versions():
def print_legacy_axolotl_text_art(suffix=None):
"""Prints axolotl ASCII art and dependency versions."""
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -46,5 +48,6 @@ def print_legacy_axolotl_text_art(suffix=None):
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process():
print(AXOLOTL_LOGO)

View File

@@ -15,6 +15,7 @@ LOG = logging.getLogger(__name__)
def check_accelerate_default_config():
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
@@ -22,6 +23,14 @@ def check_accelerate_default_config():
def check_user_token():
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
Raises:
LocalTokenNotFoundError: If HF user info can't be retrieved.
"""
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
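For reference, a minimal sketch of the check this docstring describes, assuming `huggingface_hub`'s `whoami` helper and its `LocalTokenNotFoundError` (the actual axolotl implementation may differ):

import logging
import os

from huggingface_hub import whoami
from huggingface_hub.utils import LocalTokenNotFoundError

LOG = logging.getLogger(__name__)


def check_user_token_sketch() -> bool:
    # Skip the check entirely when the Hub is configured for offline use
    if os.getenv("HF_HUB_OFFLINE") == "1":
        LOG.info("Skipping HF token check since HF_HUB_OFFLINE=1")
        return True

    try:
        whoami()  # raises LocalTokenNotFoundError if no token is configured
        return True
    except LocalTokenNotFoundError:
        LOG.warning("No HF user token found; gated models will be unavailable")
        raise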

View File

@@ -2,8 +2,9 @@
import logging
import math
import random
from typing import Union
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.train import TrainDatasetMeta
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
@@ -17,7 +18,7 @@ LOG = logging.getLogger(__name__)
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
@@ -61,7 +62,9 @@ def load_datasets(
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
cli_args: Union[
PreprocessCliArgs, TrainerCliArgs
], # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(

View File

@@ -29,7 +29,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -5,7 +5,7 @@ import logging
import sys
from pathlib import Path
from threading import Thread
from typing import Optional, Union
from typing import Union
import fire
import torch
@@ -15,7 +15,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
@@ -25,7 +25,13 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def get_multi_line_input() -> Optional[str]:
def get_multi_line_input() -> str:
"""
Gets multi-line input from terminal.
Returns:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
@@ -37,9 +43,18 @@ def get_multi_line_input() -> Optional[str]:
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: InferenceCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
"""
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
@@ -121,11 +136,20 @@ def do_inference(
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: InferenceCliArgs,
):
"""
Runs inference in a Gradio interface. User input is accepted, a chat template is
(optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
@@ -218,16 +242,25 @@ def do_inference_gradio(
)
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
def do_cli(
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
) -> None:
"""
Parses `axolotl` config, inference-specific CLI args, and calls `do_inference` or
`do_inference_gradio` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
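A usage sketch of the refactored entry point (config path hypothetical); the `gradio` flag selects between `do_inference_gradio` and the terminal loop:

from axolotl.cli.inference import do_cli

# Equivalent to: python -m axolotl.cli.inference examples/some-config.yml --gradio
do_cli(config="examples/some-config.yml", gradio=True)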

View File

@@ -29,7 +29,14 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli
@@ -46,8 +53,16 @@ def preprocess(config: str, **kwargs):
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
def train(config: str, accelerate: bool, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
@@ -74,8 +89,16 @@ def train(config: str, accelerate: bool, **kwargs):
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
"""
Evaluate a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
@@ -97,46 +120,32 @@ def evaluate(config: str, accelerate: bool, **kwargs):
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["base_model"] = base_model
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
"""
Run inference with a trained model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@@ -146,20 +155,18 @@ def inference(
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
def shard(config: str, accelerate: bool, **kwargs) -> None:
"""
Shard model weights into chunks.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
@@ -181,18 +188,18 @@ def shard(config: str, accelerate: bool, **kwargs):
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
"""
Merge sharded FSDP model weights.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
@@ -214,27 +221,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_lora(config: str, **kwargs) -> None:
"""
Merge trained LoRA adapters into a base model.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.merge_lora import do_cli
@@ -244,13 +243,17 @@ def merge_lora(
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
def fetch(directory: str, dest: Optional[str]) -> None:
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
Args:
directory: One of `examples`, `deepspeed_configs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
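The hand-written `--lora-model-dir`, `--base-model`, `--model-dir`, and similar options above are dropped in favor of `@add_options_from_dataclass`; the sketch below illustrates how such a decorator can be built from dataclass fields (an illustration, not axolotl's actual helper):

from dataclasses import MISSING, fields

import click


def add_options_from_dataclass_sketch(config_class):
    """Derive click options from a dataclass's fields (illustrative only)."""

    def decorator(function):
        # reversed() so that --help lists options in field-declaration order
        for field in reversed(fields(config_class)):
            option_name = f"--{field.name.replace('_', '-')}"
            default = None if field.default is MISSING else field.default
            if field.type in (bool, "bool"):
                function = click.option(option_name, is_flag=True, default=bool(default))(function)
            else:
                function = click.option(option_name, default=default)(function)
        return function

    return decorator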

View File

@@ -13,18 +13,23 @@ from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.merge_lora")
LOG = logging.getLogger(__name__)
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `peft`'s `merge_and_unload` on the model given in the `axolotl` config
to fold the trained LoRA adapters into a single base model.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
print_axolotl_text_art()
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
@@ -33,7 +38,7 @@ def do_merge_lora(
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
@@ -42,10 +47,21 @@ def do_merge_lora(
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine.
Note that various config values will be overwritten to allow the LoRA merge logic to work
as expected (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.).
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Raises:
ValueError: If the target directory for the merged LoRA model does not exist.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
@@ -76,7 +92,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
do_merge_lora(cfg=parsed_cfg)
if __name__ == "__main__":

View File

@@ -18,24 +18,25 @@ from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Preprocessing-specific CLI arguments.
"""
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
if not cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
@@ -43,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if cfg.rl: # and cfg.rl != "orpo":
load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
load_datasets(cfg=cfg, cli_args=cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
if cli_args.download:
model_name = cfg.base_model
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
@@ -69,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
+ Fore.RESET
)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_preprocess(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
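A usage sketch of the split entry point (config path hypothetical): `do_cli` now only parses the config and `PreprocessCliArgs` before delegating to `do_preprocess`:

from axolotl.cli.preprocess import do_cli

# Equivalent to: python -m axolotl.cli.preprocess examples/some-config.yml
do_cli(config="examples/some-config.yml")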

View File

@@ -5,39 +5,27 @@ from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.common.cli import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
def shard(*, cfg: DictDefault):
model, _ = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
shard(cfg=parsed_cfg)
if __name__ == "__main__":

View File

@@ -15,21 +15,21 @@ from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> None:
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Trains transformer model by first loading the dataset(s) specified in the `axolotl`
config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin
manager's `post_train_unload` once training completes.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -39,7 +39,7 @@ def do_train(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
@@ -48,6 +48,24 @@ def do_train(cfg, cli_args) -> None:
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_train(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,10 +1,10 @@
"""
shared module for cli specific things
"""
"""Shared module for CLI specific utilities."""
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import Any, Optional, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
@@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -30,14 +28,11 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@@ -45,25 +40,40 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
inference: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
inference = getattr(cli_args, "inference", False)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer
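A usage sketch of the simplified helper, which now takes a plain `inference` boolean in place of the old CLI-args object (config path hypothetical):

from axolotl.cli.config import load_cfg
from axolotl.common.cli import load_model_and_tokenizer

cfg = load_cfg("examples/some-config.yml")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)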

View File

@@ -9,7 +9,6 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -62,16 +61,13 @@ def evaluate_dataset(
return metrics
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
Args:
cfg: Configuration dictionary
cli_args: Command line arguments
dataset_meta: Dataset metadata containing training and evaluation datasets
cfg: Config dictionary.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:
Tuple containing:
@@ -102,9 +98,7 @@ def evaluate(
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(

View File

@@ -19,7 +19,6 @@ from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
@@ -39,14 +38,12 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.train")
LOG = get_logger(__name__)
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
@@ -54,7 +51,7 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer
LOG.debug(
@@ -93,9 +90,7 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
@@ -107,9 +102,7 @@ def train(
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True
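Putting the pieces together, a sketch of the slimmed-down training call chain after this commit (config path hypothetical):

from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train

cfg = load_cfg("examples/some-config.yml")
dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)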