Ray Train Axolotl Integration (#2251)
* current not clean working version move torch trainer to do_cli update code with config changes and clean up edit config cleanup add run name to trainer * address comments * use axolotl train in multigpu tests and add ray tests for multi-gpu * accelerate uses underscores for main_process_port arg * chore: lint * fix order of accelerate args * include ray train in docker images * current not clean working version move torch trainer to do_cli update code with config changes and clean up edit config cleanup add run name to trainer * address comments * use axolotl train in multigpu tests and add ray tests for multi-gpu * accelerate uses underscores for main_process_port arg * chore: lint * fix order of accelerate args * include ray train in docker images * fix bf16 resolution behavior * move dtype logic * x Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * rename Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * add to sidebar Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * Apply suggestions from code review Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com> * Update docs/ray-integration.qmd Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com> * pre-commit fixes Signed-off-by: SumanthRH <sumanthrh@anyscale.com> * use output_dir instead of hardcoded saves path Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * bugfix storage dir * change type\ for resources_per_worker --------- Signed-off-by: SumanthRH <sumanthrh@anyscale.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: SumanthRH <sumanthrh@anyscale.com> Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com> Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
@@ -56,33 +57,10 @@ def choose_device(cfg):
|
||||
cfg.device_map = None
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
if cfg.eval_batch_size is None:
|
||||
cfg.eval_batch_size = cfg.micro_batch_size
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
||||
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
||||
"sacrebleu",
|
||||
"comet",
|
||||
"ter",
|
||||
"chrf",
|
||||
]
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if cfg.bf16 == "auto":
|
||||
def resolve_dtype(cfg):
|
||||
if (
|
||||
cfg.bf16 == "auto" and not cfg.use_ray
|
||||
): # if we use ray we want to defer this check to the worker node
|
||||
if is_torch_bf16_gpu_available():
|
||||
LOG.debug("bf16 support detected, enabling for this configuration.")
|
||||
cfg.bf16 = True
|
||||
@@ -110,6 +88,43 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
|
||||
def normalize_config(cfg):
|
||||
# setup some derived config / hyperparams
|
||||
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
||||
cfg.batch_size // cfg.micro_batch_size
|
||||
)
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
if cfg.eval_batch_size is None:
|
||||
cfg.eval_batch_size = cfg.micro_batch_size
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
||||
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
||||
"sacrebleu",
|
||||
"comet",
|
||||
"ter",
|
||||
"chrf",
|
||||
]
|
||||
choose_device(cfg)
|
||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||
if cfg.ddp:
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if not cfg.use_ray:
|
||||
# delay resolving dtype until on worker node when launching with ray
|
||||
resolve_dtype(cfg)
|
||||
|
||||
if cfg.deepspeed:
|
||||
if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):
|
||||
ds_config_path = cfg.deepspeed
|
||||
with open(ds_config_path, encoding="utf-8") as f:
|
||||
cfg.deepspeed = json.load(f)
|
||||
|
||||
if cfg.saves_per_epoch:
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
|
||||
Reference in New Issue
Block a user