Adding documentation and continuing cleanup (in progress)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
@@ -30,19 +31,19 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parser = transformers.HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.inference:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
elif parsed_cli_args.merge_lora:
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
do_merge_lora(cfg=parsed_cfg)
|
||||
elif parsed_cli_args.shard:
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
shard(cfg=parsed_cfg)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=parsed_cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,6 +21,7 @@ AXOLOTL_LOGO = """
|
||||
|
||||
|
||||
def print_dep_versions():
|
||||
"""Prints versions of various axolotl dependencies."""
|
||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||
max_len = max(len(pkg) for pkg in packages)
|
||||
if is_main_process():
|
||||
@@ -33,6 +34,7 @@ def print_dep_versions():
|
||||
|
||||
|
||||
def print_legacy_axolotl_text_art(suffix=None):
|
||||
"""Prints axolotl ASCII art and dependency versions."""
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
@@ -46,5 +48,6 @@ def print_legacy_axolotl_text_art(suffix=None):
|
||||
|
||||
|
||||
def print_axolotl_text_art():
|
||||
"""Prints axolotl ASCII art."""
|
||||
if is_main_process():
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
@@ -15,6 +15,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_accelerate_default_config():
|
||||
"""Logs at warning level if no accelerate config file is found."""
|
||||
if Path(config_args.default_yaml_config_file).exists():
|
||||
LOG.warning(
|
||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||
@@ -22,6 +23,14 @@ def check_accelerate_default_config():
|
||||
|
||||
|
||||
def check_user_token():
|
||||
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
|
||||
|
||||
Returns:
|
||||
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
|
||||
|
||||
Raises:
|
||||
LocalTokenNotFoundError: If HF user info can't be retrieved.
|
||||
"""
|
||||
# Skip check if HF_HUB_OFFLINE is set to True
|
||||
if os.getenv("HF_HUB_OFFLINE") == "1":
|
||||
LOG.info(
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
@@ -17,7 +18,7 @@ LOG = logging.getLogger(__name__)
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
@@ -61,7 +62,9 @@ def load_datasets(
|
||||
def load_rl_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
|
||||
cli_args: Union[
|
||||
PreprocessCliArgs, TrainerCliArgs
|
||||
], # pylint: disable=unused-argument
|
||||
) -> TrainDatasetMeta:
|
||||
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
||||
total_num_steps = int(
|
||||
|
||||
@@ -29,7 +29,7 @@ def do_evaluate(cfg, cli_args) -> None:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
evaluate(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
@@ -15,7 +15,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.common.cli import InferenceCliArgs, load_model_and_tokenizer
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
@@ -25,7 +25,13 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
def get_multi_line_input() -> str:
|
||||
"""
|
||||
Gets multi-line input from terminal.
|
||||
|
||||
Returns:
|
||||
Possibly multi-line, possibly empty stdin input as a string.
|
||||
"""
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
@@ -37,9 +43,18 @@ def get_multi_line_input() -> Optional[str]:
|
||||
def do_inference(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
cli_args: InferenceCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
"""
|
||||
Runs inference on the command line in a loop. User input is accepted, a chat template
|
||||
is (optionally) applied, and the model specified in the `axolotl` config is used to
|
||||
generate completions according to a default generation config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
prompter = cli_args.prompter
|
||||
|
||||
prompter_module = None
|
||||
@@ -121,11 +136,20 @@ def do_inference(
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
cli_args: InferenceCliArgs,
|
||||
):
|
||||
"""
|
||||
Runs inference in a Gradio interface. User input is accepted, a chat template is
|
||||
(optionally) applied, and the model specified in the `axolotl` config is used to
|
||||
generate completions according to a default generation config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
import gradio as gr
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||
prompter = cli_args.prompter
|
||||
|
||||
prompter_module = None
|
||||
@@ -218,16 +242,25 @@ def do_inference_gradio(
|
||||
)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
|
||||
def do_cli(
|
||||
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Parses axolotl config, training-specific CLI args, and calls `do_inference` or
|
||||
`do_inference_gradio` as a subroutine.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
@@ -29,7 +29,14 @@ def cli():
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def preprocess(config: str, **kwargs):
|
||||
"""Preprocess datasets before training."""
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
from axolotl.cli.preprocess import do_cli
|
||||
@@ -46,8 +53,16 @@ def preprocess(config: str, **kwargs):
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def train(config: str, accelerate: bool, **kwargs):
|
||||
"""Train or fine-tune a model."""
|
||||
def train(config: str, accelerate: bool, **kwargs) -> None:
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
@@ -74,8 +89,16 @@ def train(config: str, accelerate: bool, **kwargs):
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def evaluate(config: str, accelerate: bool, **kwargs):
|
||||
"""Evaluate a model."""
|
||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
"""
|
||||
Evaluate a model.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if accelerate:
|
||||
@@ -97,46 +120,32 @@ def evaluate(config: str, accelerate: bool, **kwargs):
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
)
|
||||
@click.option(
|
||||
"--lora-model-dir",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="Directory containing LoRA model",
|
||||
)
|
||||
@click.option(
|
||||
"--base-model",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="Path to base model for non-LoRA models",
|
||||
)
|
||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def inference(
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
lora_model_dir: Optional[str] = None,
|
||||
base_model: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run inference with a trained model."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
del kwargs["inference"] # interferes with inference.do_cli
|
||||
|
||||
if lora_model_dir:
|
||||
kwargs["lora_model_dir"] = lora_model_dir
|
||||
if base_model:
|
||||
kwargs["base_model"] = base_model
|
||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
"""
|
||||
Run inference with a trained model.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
gradio: Whether to use Gradio browser interface or command line for inference.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
if gradio:
|
||||
base_cmd.append("--gradio")
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
from axolotl.cli.inference import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
do_cli(config=config, gradio=gradio, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -146,20 +155,18 @@ def inference(
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU operations",
|
||||
)
|
||||
@click.option(
|
||||
"--model-dir",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="Directory containing model weights to shard",
|
||||
)
|
||||
@click.option(
|
||||
"--save-dir",
|
||||
type=click.Path(path_type=str),
|
||||
help="Directory to save sharded weights",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def shard(config: str, accelerate: bool, **kwargs):
|
||||
"""Shard model weights."""
|
||||
def shard(config: str, accelerate: bool, **kwargs) -> None:
|
||||
"""
|
||||
Shard model weights into chunks.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if accelerate:
|
||||
@@ -181,18 +188,18 @@ def shard(config: str, accelerate: bool, **kwargs):
|
||||
default=True,
|
||||
help="Use accelerate launch for weight merging",
|
||||
)
|
||||
@click.option(
|
||||
"--model-dir",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="Directory containing sharded weights",
|
||||
)
|
||||
@click.option(
|
||||
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
|
||||
"""Merge sharded FSDP model weights."""
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
"""
|
||||
Merge sharded FSDP model weights.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if accelerate:
|
||||
@@ -214,27 +221,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--lora-model-dir",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="Directory containing the LoRA model to merge",
|
||||
)
|
||||
@click.option(
|
||||
"--output-dir",
|
||||
type=click.Path(path_type=str),
|
||||
help="Directory to save the merged model",
|
||||
)
|
||||
def merge_lora(
|
||||
config: str,
|
||||
lora_model_dir: Optional[str] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""Merge a trained LoRA into a base model"""
|
||||
kwargs = {}
|
||||
if lora_model_dir:
|
||||
kwargs["lora_model_dir"] = lora_model_dir
|
||||
if output_dir:
|
||||
kwargs["output_dir"] = output_dir
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def merge_lora(config: str, **kwargs) -> None:
|
||||
"""
|
||||
Merge trained LoRA adapters into a base model.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
from axolotl.cli.merge_lora import do_cli
|
||||
|
||||
@@ -244,13 +243,17 @@ def merge_lora(
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
"""
|
||||
Fetch example configs or other resources.
|
||||
|
||||
Available directories:
|
||||
- examples: Example configuration files
|
||||
- deepspeed_configs: DeepSpeed configuration files
|
||||
|
||||
Args:
|
||||
directory: One of `examples`, `deepspeed_configs`.
|
||||
dest: Optional destination directory.
|
||||
"""
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
@@ -13,18 +13,23 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.merge_lora")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_merge_lora(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
"""
|
||||
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
|
||||
along with the LoRA adapters to combine them into a single base model.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
LOG.info("Running merge of LoRA with base model...")
|
||||
model = model.merge_and_unload(progressbar=True)
|
||||
try:
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
@@ -33,7 +38,7 @@ def do_merge_lora(
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
@@ -42,10 +47,21 @@ def do_merge_lora(
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, training-specific CLI args, and calls `do_merge_lora` as a subroutine.
|
||||
Note that various config values will be overwritten to allow the LoRA merge logic to work
|
||||
as expected (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
|
||||
Raises:
|
||||
ValueError: If target directory for LoRA merged model does not exist.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parser = transformers.HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
@@ -76,7 +92,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
parsed_cfg.fsdp = None
|
||||
parsed_cfg.fsdp_config = None
|
||||
|
||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
do_merge_lora(cfg=parsed_cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -18,24 +18,25 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
"""
|
||||
Preprocesses dataset specified in axolotl config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Preprocessing-specific CLI arguments.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
if not cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
@@ -43,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
with disable_datasets_caching():
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
if parsed_cli_args.download:
|
||||
model_name = parsed_cfg.base_model
|
||||
if cli_args.download:
|
||||
model_name = cfg.base_model
|
||||
with warnings.catch_warnings():
|
||||
# there are a bunch of useless UserWarnings about
|
||||
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
|
||||
@@ -69,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, preprocessing-specific CLI args, and calls `do_preprocess` as a subroutine.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
do_preprocess(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,39 +5,27 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.common.cli import load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def shard(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
def shard(*, cfg: DictDefault):
|
||||
model, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
LOG.debug("Re-saving model w/ sharding")
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.shard = True
|
||||
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
shard(cfg=parsed_cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,21 +15,21 @@ from axolotl.cli.datasets import load_datasets, load_rl_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
return do_train(parsed_cfg, parsed_cli_args)
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
"""
|
||||
Trains transformer model by first loading the dataset(s) specified in the `axolotl`
|
||||
config, and then calling `axolotl.train.train` as a subroutine. Also runs the plugin
|
||||
manager's `post_train_unload` once training completes.
|
||||
|
||||
|
||||
def do_train(cfg, cli_args) -> None:
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
@@ -39,7 +39,7 @@ def do_train(cfg, cli_args) -> None:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
@@ -48,6 +48,24 @@ def do_train(cfg, cli_args) -> None:
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, training-specific CLI args, and calls `do_train` as a subroutine.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
do_train(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
shared module for cli specific things
|
||||
"""
|
||||
"""Shared module for CLI specific utilities."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
configure_logging()
|
||||
LOG = logging.getLogger("axolotl.common.cli")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
@@ -30,14 +28,11 @@ class PreprocessCliArgs:
|
||||
|
||||
@dataclass
|
||||
class TrainerCliArgs:
|
||||
"""
|
||||
dataclass representing the various non-training arguments
|
||||
"""
|
||||
"""Dataclass with CLI arguments for `axolotl train` command."""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=0)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
@@ -45,25 +40,40 @@ class TrainerCliArgs:
|
||||
|
||||
@dataclass
|
||||
class EvaluateCliArgs:
|
||||
"""
|
||||
dataclass representing the various evaluation arguments
|
||||
"""
|
||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl inference` command."""
|
||||
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
inference: bool = False,
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||
"""
|
||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
`transformers` model and tokenizer.
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
inference = getattr(cli_args, "inference", False)
|
||||
LOG.info("loading model...")
|
||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Dict, Optional
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
@@ -62,16 +61,13 @@ def evaluate_dataset(
|
||||
return metrics
|
||||
|
||||
|
||||
def evaluate(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
) -> Dict[str, float]:
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
cli_args: Command line arguments
|
||||
dataset_meta: Dataset metadata containing training and evaluation datasets
|
||||
cfg: Config dictionary.
|
||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
@@ -102,9 +98,7 @@ def evaluate(
|
||||
|
||||
# Load model
|
||||
LOG.debug("loading model for evaluation...")
|
||||
model, _ = load_model(
|
||||
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||
)
|
||||
model, _ = load_model(cfg, tokenizer, processor=processor)
|
||||
|
||||
# Set up trainer
|
||||
trainer = setup_trainer(
|
||||
|
||||
@@ -19,7 +19,6 @@ from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
@@ -39,14 +38,12 @@ src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger("axolotl.train")
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainDatasetMeta:
|
||||
"""
|
||||
dataclass to capture the dataset specific options for training
|
||||
"""
|
||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||
|
||||
train_dataset: Dataset
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
@@ -54,7 +51,7 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def train(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
@@ -93,9 +90,7 @@ def train(
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(
|
||||
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||
)
|
||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
@@ -107,9 +102,7 @@ def train(
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(
|
||||
cfg, tokenizer, inference=cli_args.inference, reference_model=True
|
||||
)
|
||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user