Ray Train Axolotl Integration (#2251)

* current

not clean working version
move torch trainer to do_cli
update code with config changes and clean up
edit config
cleanup
add run name to trainer

* address comments

* use axolotl train in multigpu tests and add ray tests for multi-gpu

* accelerate uses underscores for main_process_port arg

* chore: lint

* fix order of accelerate args

* include ray train in docker images

* fix bf16 resolution behavior

* move dtype logic

* x

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* rename

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* add to sidebar

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* Apply suggestions from code review

Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>

* Update docs/ray-integration.qmd

Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>

* pre-commit fixes

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* use output_dir instead of hardcoded saves path

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* bugfix storage dir

* change type for resources_per_worker

---------

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: SumanthRH <sumanthrh@anyscale.com>
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Commit 268543a3be (parent 54dd7abfc1)
Author: Eric Tang
Date: 2025-01-28 21:10:19 -08:00
Committed by: GitHub
16 changed files with 492 additions and 100 deletions

View File

@@ -38,6 +38,7 @@ website:
       - docs/multi-node.qmd
       - docs/unsloth.qmd
       - docs/amd_hpc.qmd
+      - docs/ray-integration.qmd
     - section: "Dataset Formats"
       contents: docs/dataset-formats/*
     - section: "Reference"

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
     fi
 RUN python scripts/unsloth_install.py | sh

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
     fi
 RUN python scripts/unsloth_install.py | sh

Binary file not shown (new image, 292 KiB)

docs/ray-integration.qmd (new file, 93 lines)
View File

@@ -0,0 +1,93 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---
Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training, since you only have to set up code and dependencies on a single node and can launch training as if you were working on a single node.
With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.
## Ray cluster setup
A prerequisite for using the Ray Train integration is a Ray cluster set up on your desired node(s). For a detailed guide on getting started with Ray clusters, see the official Ray docs: https://docs.ray.io/en/latest/cluster/getting-started.html
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs special processes related to scheduling and orchestration. Ray-enabled scripts run on the head node, and depending on the resources (number of CPUs, GPUs, etc.) their tasks request, those tasks are scheduled onto the appropriate worker nodes. For more on the key concepts behind a Ray cluster, you can refer to this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
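If you are managing nodes by hand rather than through a cluster launcher, a Ray cluster can be brought up with the `ray start` CLI. A minimal sketch, with placeholder address values:

```bash
# on the head node: start the head process (6379 is the conventional Ray port)
ray start --head --port=6379

# on each worker node: join the cluster by pointing at the head node
ray start --address='<head-node-ip>:6379'
```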
## Sanity check
To run a sanity check that your Ray cluster is set up properly, execute the following on the head node:
```bash
ray status
```
The output should contain a summary of your Ray cluster: the list of all nodes in the cluster, the number of CPUs and GPUs available, and so on. For example, for a cluster with one CPU-only head node and two 4xL40S worker nodes, the output can look like this:
```
Node status
---------------------------------------------------------------
Active:
1 head
Idle:
2 4xL40S:48CPU-384GB
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/96.0 CPU
0.0/8.0 GPU
0B/800.00GiB memory
0B/229.57GiB object_store_memory
Demands:
(no resource demands)
```
You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
## Configuring training with Ray Train
You can find an example configuration at `examples/llama-3/lora-1b-ray.yml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
  GPU: 1
...
```
- `use_ray`: Enables the Ray Train integration. You can either pass the corresponding `--use-ray` flag on the CLI or set `use_ray` in the config file.
- `ray_num_workers`: The number of workers/GPUs to use for training.
- `resources_per_worker`: The Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your Ray cluster has GPUs of different types and you only want to use NVIDIA L40S GPUs, you can do:
```yaml
resources_per_worker:
  accelerator_type:L40S: 0.001
```
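These keys can also be passed as CLI flags. The end-to-end tests, for example, launch with the equivalent command-line overrides; the config path below is illustrative:

```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray --ray-num-workers 2
```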
## Launching training
You can simply run the following command on the head node:
```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```
This launches training from the head node, and Ray Train automatically schedules the training workers on the appropriate head or worker nodes.
You can also monitor training progress on the Ray dashboard.
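If you would rather not run commands on the head node directly, the same launch can be submitted from another machine through the Ray Jobs CLI. A sketch, assuming the Ray dashboard is reachable on its default port 8265 and Axolotl is installed in the cluster's environment:

```bash
ray job submit --address http://<head-node-ip>:8265 --working-dir . \
    -- axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```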
Coming back to the earlier example of a Ray cluster with one head node and two 4xL40S worker nodes: to make use of all 8 GPUs, you can simply set `ray_num_workers: 8` and rerun the previous command. The Cluster tab will show the following:
![Ray dashboard](./images/ray-cluster-dashboard.png)

View File

@@ -0,0 +1,79 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
use_ray: true
ray_num_workers: 4

View File

@@ -150,5 +150,8 @@ setup(
"lomo-optim==0.1.1", "lomo-optim==0.1.1",
"torch-optimi==0.2.1", "torch-optimi==0.2.1",
], ],
"ray": [
"ray[train]",
],
}, },
) )
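With this extra defined, a source checkout can pull in the Ray Train dependency the same way the Docker images above do, for example:

```bash
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray]
```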

View File

@@ -25,6 +25,8 @@ class TrainerCliArgs:
     merge_lora: bool = field(default=False)
     prompter: Optional[str] = field(default=None)
     shard: bool = field(default=False)
+    main_process_port: Optional[int] = field(default=None)
+    num_processes: Optional[int] = field(default=None)
 
 
 @dataclass

View File

@@ -67,8 +67,23 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
     # Enable expandable segments for cuda allocation to improve VRAM usage
     set_pytorch_cuda_alloc_conf()
 
+    if "use_ray" in kwargs and kwargs["use_ray"]:
+        accelerate = False
+
     if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
+        accelerate_args = []
+        if "main_process_port" in kwargs:
+            main_process_port = kwargs.pop("main_process_port", None)
+            accelerate_args.append("--main_process_port")
+            accelerate_args.append(str(main_process_port))
+        if "num_processes" in kwargs:
+            num_processes = kwargs.pop("num_processes", None)
+            accelerate_args.append("--num-processes")
+            accelerate_args.append(str(num_processes))
+
+        base_cmd = ["accelerate", "launch"]
+        base_cmd.extend(accelerate_args)
+        base_cmd.extend(["-m", "axolotl.cli.train"])
         if config:
             base_cmd.append(config)
         cmd = build_command(base_cmd, kwargs)
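The two pass-through arguments let `axolotl train` forward launcher options to `accelerate launch`; the updated multi-GPU tests below invoke them like this (the port value is illustrative):

```bash
axolotl train config.yaml --num-processes 2 --main-process-port 29500
```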

View File

@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Union
 
 import fire
+from accelerate import Accelerator
 from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser
@@ -15,6 +16,7 @@ from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
+from axolotl.utils.config import normalize_config, resolve_dtype
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger(__name__)
@@ -63,7 +65,47 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         return_remaining_strings=True
     )
 
-    do_train(parsed_cfg, parsed_cli_args)
+    if parsed_cfg.use_ray:
+        from ray.train import RunConfig, ScalingConfig
+        from ray.train.torch import TorchTrainer
+
+        train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
+
+        trainer = TorchTrainer(
+            ray_train_func,
+            train_loop_config=train_loop_config,
+            scaling_config=ScalingConfig(
+                num_workers=parsed_cfg.ray_num_workers,
+                resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
+                use_gpu=True,
+            ),
+            run_config=RunConfig(
+                name=parsed_cfg.ray_run_name,
+                storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
+            ),
+        )
+        return trainer.fit()
+
+    return do_train(parsed_cfg, parsed_cli_args)
+
+
+def ray_train_func(kwargs: dict):
+    # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
+    # also renormalize the config now that TorchTrainer has spawned distributed workers
+    cfg = DictDefault(kwargs["cfg"])
+    normalize_config(cfg)
+    # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
+    resolve_dtype(cfg)
+    # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
+    if cfg.deepspeed:
+        cfg.deepspeed = cfg.deepspeed.to_dict()
+    # initialize accelerator before model instantiation
+    Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
+
+    kwargs["cfg"] = cfg
+    do_train(**kwargs)
+
 
 if __name__ == "__main__":

View File

@@ -141,7 +141,9 @@ def train(
         model.config.save_pretrained(str(Path(cfg.output_dir)))
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
-    if cfg.local_rank == 0:
+    if (
+        cfg.local_rank == 0 and not cfg.use_ray
+    ):  # ray workers don't have access to this signal
 
         def terminate_handler(_, __, model_weakref):
             if model_weakref() is not None:
View File

@@ -1,4 +1,5 @@
"""Module for working with config dicts""" """Module for working with config dicts"""
import json
import logging import logging
import os import os
from typing import Optional from typing import Optional
@@ -56,33 +57,10 @@ def choose_device(cfg):
         cfg.device_map = None
 
 
-def normalize_config(cfg):
-    # setup some derived config / hyperparams
-    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
-        cfg.batch_size // cfg.micro_batch_size
-    )
-    cfg.batch_size = (
-        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
-    )
-    if cfg.eval_batch_size is None:
-        cfg.eval_batch_size = cfg.micro_batch_size
-    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
-    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    cfg.eval_table_size = cfg.eval_table_size or 0
-    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
-    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
-        "sacrebleu",
-        "comet",
-        "ter",
-        "chrf",
-    ]
-    choose_device(cfg)
-    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
-    if cfg.ddp:
-        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.batch_size = cfg.batch_size * cfg.world_size
-
-    if cfg.bf16 == "auto":
+def resolve_dtype(cfg):
+    if (
+        cfg.bf16 == "auto" and not cfg.use_ray
+    ):  # if we use ray we want to defer this check to the worker node
         if is_torch_bf16_gpu_available():
             LOG.debug("bf16 support detected, enabling for this configuration.")
             cfg.bf16 = True
@@ -110,6 +88,43 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32
 
+
+def normalize_config(cfg):
+    # setup some derived config / hyperparams
+    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
+        cfg.batch_size // cfg.micro_batch_size
+    )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
+    if cfg.eval_batch_size is None:
+        cfg.eval_batch_size = cfg.micro_batch_size
+    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
+    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    cfg.eval_table_size = cfg.eval_table_size or 0
+    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
+        "sacrebleu",
+        "comet",
+        "ter",
+        "chrf",
+    ]
+    choose_device(cfg)
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
+    if cfg.ddp:
+        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
+    if not cfg.use_ray:
+        # delay resolving dtype until on worker node when launching with ray
+        resolve_dtype(cfg)
+
+    if cfg.deepspeed:
+        if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):
+            ds_config_path = cfg.deepspeed
+            with open(ds_config_path, encoding="utf-8") as f:
+                cfg.deepspeed = json.load(f)
+
     if cfg.saves_per_epoch:
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step

View File

@@ -607,6 +607,30 @@ class GradioConfig(BaseModel):
     gradio_temperature: Optional[float] = None
 
 
+class RayConfig(BaseModel):
+    """Ray launcher configuration subset"""
+
+    use_ray: bool = Field(default=False)
+    ray_run_name: Optional[str] = Field(
+        default=None,
+        metadata={
+            "help": "The training results will be saved at `saves/ray_run_name`."
+        },
+    )
+    ray_num_workers: int = Field(
+        default=1,
+        metadata={
+            "help": "The number of workers for Ray training. Default is 1 worker."
+        },
+    )
+    resources_per_worker: dict = Field(
+        default_factory=lambda: {"GPU": 1},
+        metadata={
+            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
+        },
+    )
+
+
 # pylint: disable=too-many-public-methods,too-many-ancestors
 class AxolotlInputConfig(
     ModelInputConfig,
@@ -619,6 +643,7 @@ class AxolotlInputConfig(
     CometConfig,
     LISAConfig,
     GradioConfig,
+    RayConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,

View File

@@ -74,15 +74,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -139,15 +137,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -214,15 +210,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -293,15 +287,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -367,15 +359,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -439,15 +429,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -520,15 +508,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -605,15 +591,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -680,15 +664,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )
@@ -755,15 +737,13 @@ class TestMultiGPULlama:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )

View File

@@ -86,14 +86,12 @@ class TestMultiGPUQwen2:
         execute_subprocess_async(
             [
-                "accelerate",
-                "launch",
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
                 "--num-processes",
                 "2",
-                "--main_process_port",
+                "--main-process-port",
                 f"{get_torch_dist_unique_port()}",
-                "-m",
-                "axolotl.cli.train",
-                str(Path(temp_dir) / "config.yaml"),
             ]
         )

View File

@@ -0,0 +1,137 @@
"""
E2E tests for multi-GPU post-training with Ray Train
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPURay:
    """
    Test cases for multi-GPU post-training with Ray Train
    """

    def test_lora_ddp(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 2048,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.05,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "tatsu-lab/alpaca",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 4,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "use_tensorboard": True,
                "use_ray": True,
                "ray_num_workers": 2,
            }
        )

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(Path(temp_dir) / "config.yaml"),
                "--use-ray",
                "--ray-num-workers",
                "2",
            ]
        )

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
        )

    @pytest.mark.parametrize(
        "gradient_accumulation_steps",
        [1, 2],
    )
    def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "sequence_len": 2048,
                "val_set_size": 0.05,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "tatsu-lab/alpaca",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
                "use_tensorboard": True,
            }
        )

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(Path(temp_dir) / "config.yaml"),
                "--use-ray",
                "--ray-num-workers",
                "2",
            ]
        )

        check_tensorboard(
            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
        )