Ray Train Axolotl Integration (#2251)
* current not clean working version move torch trainer to do_cli update code with config changes and clean up edit config cleanup add run name to trainer * address comments * use axolotl train in multigpu tests and add ray tests for multi-gpu * accelerate uses underscores for main_process_port arg * chore: lint * fix order of accelerate args * include ray train in docker images * current not clean working version move torch trainer to do_cli update code with config changes and clean up edit config cleanup add run name to trainer * address comments * use axolotl train in multigpu tests and add ray tests for multi-gpu * accelerate uses underscores for main_process_port arg * chore: lint * fix order of accelerate args * include ray train in docker images * fix bf16 resolution behavior * move dtype logic * x Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * rename Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * add to sidebar Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * Apply suggestions from code review Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com> * Update docs/ray-integration.qmd Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com> * pre-commit fixes Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * use output_dir instead of hardcoded saves path Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * bugfix storage dir * change type\ for resources_per_worker --------- Signed-off-by: SumanthRH <sumanthrh@anyscale.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: SumanthRH <sumanthrh@anyscale.com> Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com> Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -25,6 +25,8 @@ class TrainerCliArgs:
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -67,8 +67,23 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from accelerate import Accelerator
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
@@ -15,6 +16,7 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -63,7 +65,47 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
do_train(parsed_cfg, parsed_cli_args)
|
||||
if parsed_cfg.use_ray:
|
||||
from ray.train import RunConfig, ScalingConfig
|
||||
from ray.train.torch import TorchTrainer
|
||||
|
||||
train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
|
||||
trainer = TorchTrainer(
|
||||
ray_train_func,
|
||||
train_loop_config=train_loop_config,
|
||||
scaling_config=ScalingConfig(
|
||||
num_workers=parsed_cfg.ray_num_workers,
|
||||
resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
|
||||
use_gpu=True,
|
||||
),
|
||||
run_config=RunConfig(
|
||||
name=parsed_cfg.ray_run_name,
|
||||
storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
|
||||
),
|
||||
)
|
||||
return trainer.fit()
|
||||
return do_train(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
def ray_train_func(kwargs: dict):
    """Per-worker entry point executed by Ray's TorchTrainer.

    Rebuilds the training config as a ``DictDefault``, re-normalizes it on
    the worker, resolves the compute dtype against the worker's actual GPU,
    and then hands off to ``do_train``.
    """
    # Ray Tune's deepcopy has issues with DictDefault, so the config travels
    # across the cluster as a plain dict and is cast back here.
    worker_cfg = DictDefault(kwargs["cfg"])

    # Re-derive world size / local rank etc. now that TorchTrainer has
    # spawned the distributed workers.
    normalize_config(worker_cfg)

    # Dtype resolution was deferred to the worker node, where
    # `is_torch_bf16_gpu_available` reflects the real hardware.
    resolve_dtype(worker_cfg)

    # Ray serialization drops the frozen attribute, and HF expects a plain
    # dict (not a DictDefault) for the deepspeed config.
    if worker_cfg.deepspeed:
        worker_cfg.deepspeed = worker_cfg.deepspeed.to_dict()

    # The accelerator must exist before the model is instantiated.
    Accelerator(gradient_accumulation_steps=worker_cfg.gradient_accumulation_steps)

    kwargs["cfg"] = worker_cfg

    do_train(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -141,7 +141,9 @@ def train(
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if cfg.local_rank == 0:
|
||||
if (
|
||||
cfg.local_rank == 0 and not cfg.use_ray
|
||||
): # ray workers don't have access to this signal
|
||||
|
||||
def terminate_handler(_, __, model_weakref):
|
||||
if model_weakref() is not None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
@@ -56,33 +57,10 @@ def choose_device(cfg):
|
||||
cfg.device_map = None
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
if cfg.eval_batch_size is None:
|
||||
cfg.eval_batch_size = cfg.micro_batch_size
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
||||
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
||||
"sacrebleu",
|
||||
"comet",
|
||||
"ter",
|
||||
"chrf",
|
||||
]
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if cfg.bf16 == "auto":
|
||||
def resolve_dtype(cfg):
|
||||
if (
|
||||
cfg.bf16 == "auto" and not cfg.use_ray
|
||||
): # if we use ray we want to defer this check to the worker node
|
||||
if is_torch_bf16_gpu_available():
|
||||
LOG.debug("bf16 support detected, enabling for this configuration.")
|
||||
cfg.bf16 = True
|
||||
@@ -110,6 +88,43 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
if cfg.eval_batch_size is None:
|
||||
cfg.eval_batch_size = cfg.micro_batch_size
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
||||
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
||||
"sacrebleu",
|
||||
"comet",
|
||||
"ter",
|
||||
"chrf",
|
||||
]
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if not cfg.use_ray:
|
||||
# delay resolving dtype until on worker node when launching with ray
|
||||
resolve_dtype(cfg)
|
||||
|
||||
if cfg.deepspeed:
|
||||
if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):
|
||||
ds_config_path = cfg.deepspeed
|
||||
with open(ds_config_path, encoding="utf-8") as f:
|
||||
cfg.deepspeed = json.load(f)
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
|
||||
@@ -607,6 +607,30 @@ class GradioConfig(BaseModel):
|
||||
gradio_temperature: Optional[float] = None
|
||||
|
||||
|
||||
class RayConfig(BaseModel):
    """Ray launcher configuration subset"""

    # Opt-in flag: when enabled, training is launched through Ray's
    # TorchTrainer instead of `accelerate launch`.
    use_ray: bool = Field(default=False)
    # Name forwarded to Ray's RunConfig; determines the results subdirectory.
    ray_run_name: Optional[str] = Field(
        default=None,
        metadata={
            "help": "The training results will be saved at `saves/ray_run_name`."
        },
    )
    # Number of Ray training workers (passed to ScalingConfig.num_workers).
    ray_num_workers: int = Field(
        default=1,
        metadata={
            "help": "The number of workers for Ray training. Default is 1 worker."
        },
    )
    # Per-worker resource request, e.g. {"GPU": 1}; passed to
    # ScalingConfig.resources_per_worker after .to_dict() conversion.
    resources_per_worker: dict = Field(
        default_factory=lambda: {"GPU": 1},
        metadata={
            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
        },
    )
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods,too-many-ancestors
|
||||
class AxolotlInputConfig(
|
||||
ModelInputConfig,
|
||||
@@ -619,6 +643,7 @@ class AxolotlInputConfig(
|
||||
CometConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RayConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
|
||||
Reference in New Issue
Block a user