From 1d8f500709e63637e89537653a72b2bdf7de3ee8 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 23 Jun 2025 09:07:57 -0400
Subject: [PATCH] deepspeed fix (#2820)

---
 src/axolotl/utils/distributed.py | 13 ++++++++++---
 src/axolotl/utils/trainer.py     |  8 +++++++-
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index b509ad0ca..2192e7b9d 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -46,16 +46,23 @@ def get_current_device() -> int:
     return 0
 
 
+def init_distributed_state():
+    global distributed_state  # pylint: disable=global-statement
+    if distributed_state is None:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+
 def get_distributed_state() -> PartialState | None:
     return distributed_state
 
 
 def is_distributed() -> bool:
     """Check if distributed training is initialized."""
-    global distributed_state  # pylint: disable=global-statement
+    init_distributed_state()
+
     if distributed_state is None:
-        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
-        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+        return False
 
     return distributed_state.use_distributed and distributed_state.initialized
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index e996cd62b..633dffde5 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
-from axolotl.utils.distributed import reduce_and_broadcast
+from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -537,6 +537,12 @@ def setup_deepspeed_env(cfg, stage=None):
         os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
         if stage == 3:
             os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+
+    # NOTE(djsaunde): The distributed state cannot be initialized prior to the
+    # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
+    # to model load.
+    init_distributed_state()
+
     # If we don't assign this, it doesn't actually get set in the accelerate weakref
     _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
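
Ordering sketch (illustrative only, not part of the patch): the NOTE above says the distributed state must be created after the ACCELERATE_USE_DEEPSPEED assignment but before the model is loaded. The snippet below is a minimal standalone sketch of that ordering; the PartialState call mirrors init_distributed_state() from the patch, while load_model and cfg are hypothetical stand-ins.

import os
from datetime import timedelta

from accelerate import PartialState

# 1. DeepSpeed-related env vars first, so accelerate sees them when the state is built.
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = "3"

# 2. Then create the distributed state (same call as init_distributed_state in the patch).
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
state = PartialState(timeout=timedelta(seconds=timeout))

# 3. Only now load the model (hypothetical stand-in for the real call).
# model = load_model(cfg)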