Adding documentation and continuing cleanup (in progress)

This commit is contained in:
Dan Saunders
2025-01-07 20:16:39 +00:00
parent 324c533adb
commit 773c3b51cd
14 changed files with 283 additions and 192 deletions

View File

@@ -1,4 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora."""
import logging import logging
from pathlib import Path from pathlib import Path
@@ -30,19 +31,19 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cli_args.inference: if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora: elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) do_merge_lora(cfg=parsed_cfg)
elif parsed_cli_args.shard: elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args) shard(cfg=parsed_cfg)
else: else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, dataset_meta=dataset_meta)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -21,6 +21,7 @@ AXOLOTL_LOGO = """
def print_dep_versions(): def print_dep_versions():
"""Prints versions of various axolotl dependencies."""
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages) max_len = max(len(pkg) for pkg in packages)
if is_main_process(): if is_main_process():
@@ -33,6 +34,7 @@ def print_dep_versions():
def print_legacy_axolotl_text_art(suffix=None): def print_legacy_axolotl_text_art(suffix=None):
"""Prints axolotl ASCII art and dependency versions."""
font = "nancyj" font = "nancyj"
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:
@@ -46,5 +48,6 @@ def print_legacy_axolotl_text_art(suffix=None):
def print_axolotl_text_art(): def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process(): if is_main_process():
print(AXOLOTL_LOGO) print(AXOLOTL_LOGO)

View File

@@ -15,6 +15,7 @@ LOG = logging.getLogger(__name__)
def check_accelerate_default_config(): def check_accelerate_default_config():
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists(): if Path(config_args.default_yaml_config_file).exists():
LOG.warning( LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
@@ -22,6 +23,14 @@ def check_accelerate_default_config():
def check_user_token(): def check_user_token():
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
Raises:
LocalTokenNotFoundError: If HF user info can't be retrieved.
"""
# Skip check if HF_HUB_OFFLINE is set to True # Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1": if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info( LOG.info(

View File

@@ -2,8 +2,9 @@
import logging import logging
import math import math
import random import random
from typing import Union
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_dpo_datasets
@@ -17,7 +18,7 @@ LOG = logging.getLogger(__name__)
def load_datasets( def load_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
@@ -61,7 +62,9 @@ def load_datasets(
def load_rl_datasets( def load_rl_datasets(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument cli_args: Union[
PreprocessCliArgs, TrainerCliArgs
], # pylint: disable=unused-argument
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int( total_num_steps = int(

View File

@@ -29,7 +29,7 @@ def do_evaluate(cfg, cli_args) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -5,7 +5,7 @@ import logging
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Optional, Union from typing import Union
import fire import fire
import torch import torch
@@ -15,7 +15,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer
from axolotl.utils.chat_templates import ( from axolotl.utils.chat_templates import (
get_chat_template, get_chat_template,
get_chat_template_from_config, get_chat_template_from_config,
@@ -25,7 +25,13 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def get_multi_line_input() -> Optional[str]: def get_multi_line_input() -> str:
"""
Gets multi-line input from terminal.
Returns:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ") print("Give me an instruction (Ctrl + D to submit): ")
instruction = "" instruction = ""
for line in sys.stdin: for line in sys.stdin:
@@ -37,9 +43,18 @@ def get_multi_line_input() -> Optional[str]:
def do_inference( def do_inference(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: InferenceCliArgs,
): ):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) """
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter prompter = cli_args.prompter
prompter_module = None prompter_module = None
@@ -121,11 +136,20 @@ def do_inference(
def do_inference_gradio( def do_inference_gradio(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: InferenceCliArgs,
): ):
"""
Runs inference in a Gradio interface. User input is accepted, a chat template is
(optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
import gradio as gr import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter prompter = cli_args.prompter
prompter_module = None prompter_module = None
@@ -218,16 +242,25 @@ def do_inference_gradio(
) )
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs): def do_cli(
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
) -> None:
"""
Parses axolotl config, training-specific CLI args, and calls `do_inference` or
`do_inference_gradio` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs) parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
parsed_cli_args.inference = True
if gradio: if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args) do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -29,7 +29,14 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs) @add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs): def preprocess(config: str, **kwargs):
"""Preprocess datasets before training.""" """
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli from axolotl.cli.preprocess import do_cli
@@ -46,8 +53,16 @@ def preprocess(config: str, **kwargs):
) )
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs): def train(config: str, accelerate: bool, **kwargs) -> None:
"""Train or fine-tune a model.""" """
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage # Enable expandable segments for cuda allocation to improve VRAM usage
@@ -74,8 +89,16 @@ def train(config: str, accelerate: bool, **kwargs):
) )
@add_options_from_dataclass(EvaluateCliArgs) @add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def evaluate(config: str, accelerate: bool, **kwargs): def evaluate(config: str, accelerate: bool, **kwargs) -> None:
"""Evaluate a model.""" """
Evaluate a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate: if accelerate:
@@ -97,46 +120,32 @@ def evaluate(config: str, accelerate: bool, **kwargs):
default=False, default=False,
help="Use accelerate launch for multi-GPU inference", help="Use accelerate launch for multi-GPU inference",
) )
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface") @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def inference( def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
config: str, """
accelerate: bool, Run inference with a trained model.
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["base_model"] = base_model
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate: if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config: if config:
base_cmd.append(config) base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
cmd = build_command(base_cmd, kwargs) cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603 subprocess.run(cmd, check=True) # nosec B603
else: else:
from axolotl.cli.inference import do_cli from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs) do_cli(config=config, gradio=gradio, **kwargs)
@cli.command() @cli.command()
@@ -146,20 +155,18 @@ def inference(
default=False, default=False,
help="Use accelerate launch for multi-GPU operations", help="Use accelerate launch for multi-GPU operations",
) )
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs): def shard(config: str, accelerate: bool, **kwargs) -> None:
"""Shard model weights.""" """
Shard model weights into chunks.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate: if accelerate:
@@ -181,18 +188,18 @@ def shard(config: str, accelerate: bool, **kwargs):
default=True, default=True,
help="Use accelerate launch for weight merging", help="Use accelerate launch for weight merging",
) )
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
"""Merge sharded FSDP model weights.""" """
Merge sharded FSDP model weights.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate: if accelerate:
@@ -214,27 +221,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
@cli.command() @cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option( @add_options_from_dataclass(TrainerCliArgs)
"--lora-model-dir", @add_options_from_config(AxolotlInputConfig)
type=click.Path(exists=True, path_type=str), def merge_lora(config: str, **kwargs) -> None:
help="Directory containing the LoRA model to merge", """
) Merge trained LoRA adapters into a base model.
@click.option(
"--output-dir", Args:
type=click.Path(path_type=str), config: Path to `axolotl` config YAML file.
help="Directory to save the merged model", accelerate: Whether to use `accelerate` launcher.
) kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
def merge_lora( config options.
config: str, """
lora_model_dir: Optional[str] = None, kwargs = {k: v for k, v in kwargs.items() if v is not None}
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
from axolotl.cli.merge_lora import do_cli from axolotl.cli.merge_lora import do_cli
@@ -244,13 +243,17 @@ def merge_lora(
@cli.command() @cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory") @click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]): def fetch(directory: str, dest: Optional[str]) -> None:
""" """
Fetch example configs or other resources. Fetch example configs or other resources.
Available directories: Available directories:
- examples: Example configuration files - examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files - deepspeed_configs: DeepSpeed configuration files
Args:
directory: One of `examples`, `deepspeed_configs`.
dest: Optional destination directory.
""" """
fetch_from_github(f"{directory}/", dest) fetch_from_github(f"{directory}/", dest)

View File

@@ -13,18 +13,23 @@ from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.merge_lora") LOG = logging.getLogger(__name__)
def do_merge_lora( def do_merge_lora(*, cfg: DictDefault) -> None:
*, """
cfg: DictDefault, Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
cli_args: TrainerCliArgs, along with the LoRA adapters to combine them into a single base model.
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
print_axolotl_text_art()
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model") LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True) model = model.merge_and_unload(progressbar=True)
try: try:
model.to(dtype=cfg.torch_dtype) model.to(dtype=cfg.torch_dtype)
@@ -33,7 +38,7 @@ def do_merge_lora(
model.generation_config.do_sample = True model.generation_config.do_sample = True
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained( model.save_pretrained(
str(Path(cfg.output_dir) / "merged"), str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
@@ -42,10 +47,21 @@ def do_merge_lora(
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine.
Note that various config values will be overwritten to allow the LoRA merge logic to work
as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() parser = transformers.HfArgumentParser(TrainerCliArgs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
@@ -76,7 +92,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg.fsdp = None parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args) do_merge_lora(cfg=parsed_cfg)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -18,24 +18,25 @@ from axolotl.cli.config import load_cfg
from axolotl.cli.datasets import load_datasets, load_rl_datasets from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import PreprocessCliArgs from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
# pylint: disable=duplicate-code """
Preprocesses dataset specified in axolotl config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Preprocessing-specific CLI arguments.
"""
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path: if not cfg.dataset_prepared_path:
msg = ( msg = (
Fore.RED Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, " + "preprocess CLI called without dataset_prepared_path set, "
@@ -43,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+ Fore.RESET + Fore.RESET
) )
LOG.warning(msg) LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching(): with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo": if cfg.rl: # and cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_rl_datasets(cfg=cfg, cli_args=cli_args)
else: else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) load_datasets(cfg=cfg, cli_args=cli_args)
if parsed_cli_args.download: if cli_args.download:
model_name = parsed_cfg.base_model model_name = cfg.base_model
with warnings.catch_warnings(): with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about # there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model" # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
@@ -69,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.info( LOG.info(
Fore.GREEN Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`" + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
+ Fore.RESET + Fore.RESET
) )
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_preprocess(parsed_cfg, parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":
load_dotenv() load_dotenv()
fire.Fire(do_cli) fire.Fire(do_cli)

View File

@@ -5,39 +5,27 @@ from pathlib import Path
from typing import Union from typing import Union
import fire import fire
import transformers
from dotenv import load_dotenv from dotenv import load_dotenv
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def shard( def shard(*, cfg: DictDefault):
*, model, _ = load_model_and_tokenizer(cfg=cfg)
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding") LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs)) shard(cfg=parsed_cfg)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -15,21 +15,21 @@ from axolotl.cli.datasets import load_datasets, load_rl_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.train import train from axolotl.train import train
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
# pylint: disable=duplicate-code """
parsed_cfg = load_cfg(config, **kwargs) Trains transformer model by first loading the dataset(s) specified in the `axolotl`
parser = HfArgumentParser((TrainerCliArgs)) config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin
parsed_cli_args, _ = parser.parse_args_into_dataclasses( manager's `post_train_unload` once training completes.
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)
Args:
def do_train(cfg, cli_args) -> None: cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
@@ -39,7 +39,7 @@ def do_train(cfg, cli_args) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
del model del model
@@ -48,6 +48,24 @@ def do_train(cfg, cli_args) -> None:
plugin_manager.post_train_unload(cfg) plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_train(parsed_cfg, parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":
load_dotenv() load_dotenv()
fire.Fire(do_cli) fire.Fire(do_cli)

View File

@@ -1,10 +1,10 @@
""" """Shared module for CLI specific utilities."""
shared module for cli specific things
"""
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Any, Optional, Tuple
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
configure_logging() configure_logging()
LOG = logging.getLogger("axolotl.common.cli") LOG = logging.getLogger(__name__)
@dataclass @dataclass
class PreprocessCliArgs: class PreprocessCliArgs:
""" """Dataclass with CLI arguments for `axolotl preprocess` command."""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
@@ -30,14 +28,11 @@ class PreprocessCliArgs:
@dataclass @dataclass
class TrainerCliArgs: class TrainerCliArgs:
""" """Dataclass with CLI arguments for `axolotl train` command."""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0) debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
@@ -45,25 +40,40 @@ class TrainerCliArgs:
@dataclass @dataclass
class EvaluateCliArgs: class EvaluateCliArgs:
""" """Dataclass with CLI arguments for `axolotl evaluate` command."""
dataclass representing the various evaluation arguments
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0) debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, inference: bool = False,
): ) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...") LOG.info("loading model...")
inference = getattr(cli_args, "inference", False)
model, _ = load_model(cfg, tokenizer, inference=inference) model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer return model, tokenizer

View File

@@ -9,7 +9,6 @@ from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -62,16 +61,13 @@ def evaluate_dataset(
return metrics return metrics
def evaluate( def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets
Args: Args:
cfg: Configuration dictionary cfg: Config dictionary.
cli_args: Command line arguments dataset_meta: Dataset metadata containing training and evaluation datasets.
dataset_meta: Dataset metadata containing training and evaluation datasets
Returns: Returns:
Tuple containing: Tuple containing:
@@ -102,9 +98,7 @@ def evaluate(
# Load model # Load model
LOG.debug("loading model for evaluation...") LOG.debug("loading model for evaluation...")
model, _ = load_model( model, _ = load_model(cfg, tokenizer, processor=processor)
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer # Set up trainer
trainer = setup_trainer( trainer = setup_trainer(

View File

@@ -19,7 +19,6 @@ from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
@@ -39,14 +38,12 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = get_logger("axolotl.train") LOG = get_logger(__name__)
@dataclass @dataclass
class TrainDatasetMeta: class TrainDatasetMeta:
""" """Dataclass with fields for training and validation datasets and metadata."""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset train_dataset: Dataset
eval_dataset: Optional[Dataset] = None eval_dataset: Optional[Dataset] = None
@@ -54,7 +51,7 @@ class TrainDatasetMeta:
def train( def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(
@@ -93,9 +90,7 @@ def train(
if cfg.adapter: if cfg.adapter:
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
model, peft_config = load_model( model, peft_config = load_model(cfg, tokenizer, processor=processor)
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
if model.generation_config is not None: if model.generation_config is not None:
model.generation_config.do_sample = True model.generation_config.do_sample = True
@@ -107,9 +102,7 @@ def train(
model_ref = None # explicit setting to None model_ref = None # explicit setting to None
else: else:
# load the model again for model_ref/baseline # load the model again for model_ref/baseline
model_ref, _ = load_model( model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True