When using Ray use prepare for dataloader fixes (#3198)
* make sure to use ray prepare for dataloader fixes
* ray tests use torch 2.7.0+
* don't call init_distributed with ray and deepspeed
* handle dict deepspeed config
* better handling of dict deepspeed config
* use json.dumps
* guard to_dict
* wrap import for optional ray
This commit is contained in:
@@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict):
|
||||
resolve_dtype(cfg)
|
||||
|
||||
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
|
||||
if cfg.deepspeed:
|
||||
if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"):
|
||||
cfg.deepspeed = cfg.deepspeed.to_dict()
|
||||
|
||||
# initialize accelerator before model instantiation
|
||||
|
||||
@@ -525,6 +525,17 @@ def setup_model_and_trainer(
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
if cfg.use_ray:
|
||||
try:
|
||||
import ray.train.huggingface.transformers
|
||||
|
||||
trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
|
||||
except ImportError:
|
||||
LOG.warning(
|
||||
"The Ray integration with Hugging Face Transformers is not available. "
|
||||
"To use Ray, install the 'ray[train]' package."
|
||||
)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
model,
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -15,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -540,6 +542,13 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
)
|
||||
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
if isinstance(cfg.deepspeed, DictDefault):
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".json", prefix="deepspeed_config_"
|
||||
) as temp_file:
|
||||
temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4))
|
||||
temp_file.close()
|
||||
cfg.deepspeed = str(temp_file.name)
|
||||
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
|
||||
cfg.gradient_accumulation_steps
|
||||
@@ -562,6 +571,7 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
if (
|
||||
int(os.environ.get("WORLD_SIZE", "1")) == 1
|
||||
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
|
||||
and cfg.use_ray is not True
|
||||
):
|
||||
os.environ["WORLD_SIZE"] = "1" # force it in case not set
|
||||
os.environ["LOCAL_RANK"] = "0" # force it in case not set
|
||||
@@ -638,11 +648,15 @@ def prepare_optim_env(cfg):
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
stage = None
|
||||
deepspeed_config = None
|
||||
# check if the cfg.deepspeed is a file
|
||||
if os.path.isfile(cfg.deepspeed):
|
||||
if isinstance(cfg.deepspeed, DictDefault):
|
||||
deepspeed_config = cfg.deepspeed
|
||||
elif os.path.isfile(cfg.deepspeed):
|
||||
# parse with json
|
||||
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
|
||||
deepspeed_config = json.load(fin)
|
||||
if deepspeed_config:
|
||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||
setup_deepspeed_env(cfg, stage=stage)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from tests.e2e.utils import (
|
||||
check_tensorboard,
|
||||
require_torch_2_7_0,
|
||||
require_torch_lt_2_6_0,
|
||||
)
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
@@ -24,7 +23,7 @@ class TestMultiGPURay:
|
||||
Test cases for AnyScale Ray post training
|
||||
"""
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -83,7 +82,7 @@ class TestMultiGPURay:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
|
||||
Reference in New Issue
Block a user