# Changelog (from upstream PR): allow plugins to return their own dataset;
# add post_trainer_create and wire it up; add hook check; address PR feedback;
# remove annotation causing a circular import.
"""CLI to run training on a model."""
|
|
|
|
import gc
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import fire
|
|
from accelerate import Accelerator
|
|
from dotenv import load_dotenv
|
|
from transformers.hf_argparser import HfArgumentParser
|
|
|
|
from axolotl.cli.args import TrainerCliArgs
|
|
from axolotl.cli.art import print_axolotl_text_art
|
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
|
from axolotl.cli.config import load_cfg
|
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
|
from axolotl.integrations.base import PluginManager
|
|
from axolotl.train import train
|
|
from axolotl.utils import patch_optimized_env
|
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
# Module-level logger, named after this module per standard logging convention.
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    """
    Trains a `transformers` model by first loading the dataset(s) specified in the
    `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
    manager's `post_train_unload` once training completes.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        cli_args: Training-specific CLI arguments.
    """
    # Enable expandable segments for cuda allocation to improve VRAM usage
    patch_optimized_env()

    print_axolotl_text_art()
    check_accelerate_default_config()
    # Only the rank-0 process checks the user token, so multi-GPU runs don't
    # hammer the hub with duplicate auth requests.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        check_user_token()

    # Plugins get first crack at supplying datasets; fall back to the standard
    # loaders (preference datasets for RL, regular datasets otherwise).
    plugin_manager = PluginManager.get_instance()
    dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
    if not dataset_meta:
        if cfg.rl:
            dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
        else:
            dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

    # Drop references and force a collection so the trained model can be freed
    # before plugins run their post-training cleanup.
    del model, tokenizer, trainer
    gc.collect()

    # Fix: PluginManager is a singleton and we already hold the instance from
    # above — no need to call get_instance() a second time.
    plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """
    Parses `axolotl` config, CLI args, and calls `do_train`.

    Args:
        config: Path to `axolotl` config YAML file.
        kwargs: Additional keyword arguments to override config file values.
    """
    # pylint: disable=duplicate-code
    cfg = load_cfg(config, **kwargs)
    cli_args, _ = HfArgumentParser(TrainerCliArgs).parse_args_into_dataclasses(
        return_remaining_strings=True
    )

    # Plain (non-ray) path: run training directly in this process.
    if not cfg.use_ray:
        return do_train(cfg, cli_args)

    # Ray path: let a TorchTrainer spawn distributed workers, each of which
    # executes `ray_train_func` with the serialized config.
    from ray.train import RunConfig, ScalingConfig
    from ray.train.torch import TorchTrainer

    ray_trainer = TorchTrainer(
        ray_train_func,
        train_loop_config={"cfg": cfg.to_dict(), "cli_args": cli_args},
        scaling_config=ScalingConfig(
            num_workers=cfg.ray_num_workers,
            resources_per_worker=cfg.resources_per_worker.to_dict(),
            use_gpu=True,
        ),
        run_config=RunConfig(
            name=cfg.ray_run_name,
            storage_path=Path(cfg.output_dir).absolute().as_posix(),
        ),
    )
    return ray_trainer.fit()
def ray_train_func(kwargs: dict):
    """
    Per-worker entry point for Ray `TorchTrainer`: rebuilds the config from the
    serialized `train_loop_config` and delegates to `do_train`.

    Args:
        kwargs: Dict with keys `"cfg"` (plain-dict form of the axolotl config)
            and `"cli_args"` (training CLI args), as assembled in `do_cli`.
    """
    # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
    # also renormalize the config now that TorchTrainer has spawned distributed workers
    cfg = DictDefault(kwargs["cfg"])
    normalize_config(cfg)

    # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
    resolve_dtype(cfg)

    # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
    if cfg.deepspeed:
        cfg.deepspeed = cfg.deepspeed.to_dict()

    # initialize accelerator before model instantiation
    Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)

    # swap the plain-dict cfg for the restored DictDefault before forwarding
    kwargs["cfg"] = cfg

    do_train(**kwargs)
if __name__ == "__main__":
|
|
load_dotenv()
|
|
fire.Fire(do_cli)
|