diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2332717e7..6b3bfbd57 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -99,7 +99,7 @@ def ray_train_func(kwargs: dict): resolve_dtype(cfg) # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict - if cfg.deepspeed: + if cfg.deepspeed and hasattr(cfg.deepspeed, "to_dict"): cfg.deepspeed = cfg.deepspeed.to_dict() # initialize accelerator before model instantiation diff --git a/src/axolotl/train.py b/src/axolotl/train.py index da7b63121..441c50871 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -525,6 +525,17 @@ def setup_model_and_trainer( plugin_manager = PluginManager.get_instance() plugin_manager.post_trainer_create(cfg, trainer) + if cfg.use_ray: + try: + import ray.train.huggingface.transformers + + trainer = ray.train.huggingface.transformers.prepare_trainer(trainer) + except ImportError: + LOG.warning( + "The Ray integration with Hugging Face Transformers is not available. " + "To use Ray, install the 'ray[train]' package." + ) + return ( trainer, model, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 56fbe34c0..c7fa0a647 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -6,6 +6,7 @@ import os import random from contextlib import contextmanager from functools import partial +from tempfile import NamedTemporaryFile from typing import List, Optional import numpy as np @@ -15,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available +from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.logging import get_logger @@ -540,6 +542,13 @@ def setup_deepspeed_env(cfg, stage=None): ) os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + if isinstance(cfg.deepspeed, DictDefault): + with NamedTemporaryFile( + mode="w", delete=False, suffix=".json", prefix="deepspeed_config_" + ) as temp_file: + temp_file.write(json.dumps(cfg.deepspeed.to_dict(), indent=4)) + temp_file.close() + cfg.deepspeed = str(temp_file.name) os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str( cfg.gradient_accumulation_steps @@ -562,6 +571,7 @@ def setup_deepspeed_env(cfg, stage=None): if ( int(os.environ.get("WORLD_SIZE", "1")) == 1 and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1" + and cfg.use_ray is not True ): os.environ["WORLD_SIZE"] = "1" # force it in case not set os.environ["LOCAL_RANK"] = "0" # force it in case not set @@ -638,11 +648,15 @@ def prepare_optim_env(cfg): setup_fsdp_envs(cfg) elif cfg.deepspeed: stage = None + deepspeed_config = None # check if the cfg.deepspeed is a file - if os.path.isfile(cfg.deepspeed): + if isinstance(cfg.deepspeed, DictDefault): + deepspeed_config = cfg.deepspeed + elif os.path.isfile(cfg.deepspeed): # parse with json with open(cfg.deepspeed, "r", encoding="utf-8") as fin: deepspeed_config = json.load(fin) + if deepspeed_config: stage = deepspeed_config.get("zero_optimization", {}).get("stage", None) setup_deepspeed_env(cfg, stage=stage) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 7c6ea8a1f..df41b1444 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -13,7 +13,6 @@ from axolotl.utils.dict import DictDefault from tests.e2e.utils import ( check_tensorboard, require_torch_2_7_0, - require_torch_lt_2_6_0, ) AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent @@ -24,7 +23,7 @@ class TestMultiGPURay: Test cases for AnyScale Ray post training """ - @require_torch_lt_2_6_0 + @require_torch_2_7_0 def test_lora_ddp(self, temp_dir): cfg = DictDefault( { @@ -83,7 +82,7 @@ class TestMultiGPURay: temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) - @require_torch_lt_2_6_0 + @require_torch_2_7_0 @pytest.mark.parametrize( "gradient_accumulation_steps", [1, 2],