When using Ray, use Ray's prepare_trainer for dataloader fixes (#3198)

* make sure to use Ray's prepare_trainer for dataloader fixes

* ray tests require torch 2.7.0+

* don't call init_distributed with ray and deepspeed

* handle dict deepspeed config

* better handling of dict deepspeed config

* use json.dumps

* guard to_dict

* wrap import for optional ray
This commit is contained in:
Wing Lian
2025-10-08 10:43:41 -04:00
committed by GitHub
parent 4c3488cc9f
commit d0e9c3c1c5
4 changed files with 29 additions and 5 deletions

View File

@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
resolve_dtype(cfg) resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed: if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
cfg.deepspeed = cfg.deepspeed.to_dict() cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation # initialize accelerator before model instantiation

View File

@@ -525,6 +525,17 @@ def setup_model_and_trainer(
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer) plugin_manager.post_trainer_create(cfg, trainer)
if cfg.use_ray:
try:
import ray.train.huggingface.transformers
trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
except ImportError:
LOG.warning(
"The Ray integration with Hugging Face Transformers is not available. "
"To use Ray, install the 'ray[train]' package."
)
return ( return (
trainer, trainer,
model, model,

View File

@@ -6,6 +6,7 @@ import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from tempfile import NamedTemporaryFile
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
@@ -15,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -540,6 +542,13 @@ def setup_deepspeed_env(cfg, stage=None):
) )
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
if isinstance(cfg.deepspeed, DictDefault):
with NamedTemporaryFile(
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
) as temp_file:
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
temp_file.close()
cfg.deepspeed = str(temp_file.name)
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str( os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps cfg.gradient_accumulation_steps
@@ -562,6 +571,7 @@ def setup_deepspeed_env(cfg, stage=None):
if ( if (
int(os.environ.get("WORLD_SIZE", "1")) == 1 int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1" and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
and cfg.use_ray is not True
): ):
os.environ["WORLD_SIZE"] = "1" # force it in case not set os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set os.environ["LOCAL_RANK"] = "0" # force it in case not set
@@ -638,11 +648,15 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
stage = None stage = None
deepspeed_config = None
# check if the cfg.deepspeed is a file # check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed): if isinstance(cfg.deepspeed, DictDefault):
deepspeed_config = cfg.deepspeed
elif os.path.isfile(cfg.deepspeed):
# parse with json # parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin: with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin) deepspeed_config = json.load(fin)
if deepspeed_config:
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage) setup_deepspeed_env(cfg, stage=stage)

View File

@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import ( from tests.e2e.utils import (
check_tensorboard, check_tensorboard,
require_torch_2_7_0, require_torch_2_7_0,
require_torch_lt_2_6_0,
) )
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -24,7 +23,7 @@ class TestMultiGPURay:
Test cases for AnyScale Ray post training Test cases for AnyScale Ray post training
""" """
@require_torch_lt_2_6_0 @require_torch_2_7_0
def test_lora_ddp(self, temp_dir): def test_lora_ddp(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -83,7 +82,7 @@ class TestMultiGPURay:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
) )
@require_torch_lt_2_6_0 @require_torch_2_7_0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"gradient_accumulation_steps", "gradient_accumulation_steps",
[1, 2], [1, 2],