deepspeed fix (#2820)

Dan Saunders
2025-06-23 09:07:57 -04:00
committed by GitHub
parent 0494359c6c
commit 1d8f500709
2 changed files with 17 additions and 4 deletions

axolotl/utils/distributed.py

@@ -46,16 +46,23 @@ def get_current_device() -> int:
     return 0
 
 
+def init_distributed_state():
+    global distributed_state  # pylint: disable=global-statement
+    if distributed_state is None:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+def get_distributed_state() -> PartialState | None:
+    return distributed_state
+
 
 def is_distributed() -> bool:
     """Check if distributed training is initialized."""
-    global distributed_state  # pylint: disable=global-statement
+    init_distributed_state()
     if distributed_state is None:
-        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
-        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+        return False
     return distributed_state.use_distributed and distributed_state.initialized
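
Downstream usage of the new helpers looks roughly like the sketch below (not part of the patch; process_index and num_processes are standard accelerate PartialState attributes, and the AXOLOTL_NCCL_TIMEOUT override is read only once, when the state is first created):

import os

from axolotl.utils.distributed import (
    get_distributed_state,
    init_distributed_state,
    is_distributed,
)

# Optionally raise the NCCL timeout before first init; the env var is
# consulted only inside the one-time PartialState construction.
os.environ.setdefault("AXOLOTL_NCCL_TIMEOUT", "3600")

init_distributed_state()  # idempotent: repeat calls are no-ops
if is_distributed():
    state = get_distributed_state()
    print(f"rank {state.process_index} of {state.num_processes}")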

View File

@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
-from axolotl.utils.distributed import reduce_and_broadcast
+from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -537,6 +537,12 @@ def setup_deepspeed_env(cfg, stage=None):
     os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
     if stage == 3:
         os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+
+    # NOTE(djsaunde): The distributed state cannot be initialized prior to the
+    # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time
+    # prior to model load.
+    init_distributed_state()
+
     # If we don't assign this, it doesn't actually get set in the accelerate weakref
     _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
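
The ordering in the NOTE matters because accelerate's PartialState decides whether to use the DeepSpeed backend from ACCELERATE_USE_DEEPSPEED at construction time. A sketch of the resulting call order (the ACCELERATE_USE_DEEPSPEED assignment and the model-load step are implied by the comment rather than shown in this hunk; load_model is a hypothetical stand-in):

import os

from axolotl.utils.distributed import init_distributed_state

# 1. Set the accelerate/DeepSpeed env vars first; PartialState reads them
#    when it is constructed.
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = "3"
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"

# 2. Only now create the distributed state (lazy and idempotent).
init_distributed_state()

# 3. Load the model afterwards, so ZeRO-3 init sees a live state.
# model = load_model(cfg)  # hypothetical loader, for illustration only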