deepspeed fix (#2820)
This commit is contained in:
@@ -46,16 +46,23 @@ def get_current_device() -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_state():
|
||||||
|
global distributed_state # pylint: disable=global-statement
|
||||||
|
if distributed_state is None:
|
||||||
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||||
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||||
|
|
||||||
|
|
||||||
def get_distributed_state() -> PartialState | None:
|
def get_distributed_state() -> PartialState | None:
|
||||||
return distributed_state
|
return distributed_state
|
||||||
|
|
||||||
|
|
||||||
def is_distributed() -> bool:
|
def is_distributed() -> bool:
|
||||||
"""Check if distributed training is initialized."""
|
"""Check if distributed training is initialized."""
|
||||||
global distributed_state # pylint: disable=global-statement
|
init_distributed_state()
|
||||||
|
|
||||||
if distributed_state is None:
|
if distributed_state is None:
|
||||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
return False
|
||||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
|
||||||
|
|
||||||
return distributed_state.use_distributed and distributed_state.initialized
|
return distributed_state.use_distributed and distributed_state.initialized
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -537,6 +537,12 @@ def setup_deepspeed_env(cfg, stage=None):
|
|||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
|
||||||
|
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
|
||||||
|
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||||
|
# to model load.
|
||||||
|
init_distributed_state()
|
||||||
|
|
||||||
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
||||||
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user