diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index 71676bc84..fb91f81e5 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -63,15 +63,6 @@ Start from Stage 1 -> Stage 2 -> Stage 3. ::: -::: {.callout-tip} - -Using ZeRO Stage 3 with Single-GPU training - -ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables: -`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500` - -::: - ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp} ::: {.callout-note} diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 08038cb18..43f76c0cd 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -547,6 +547,13 @@ def setup_deepspeed_env(cfg, stage=None): if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" + device_count = torch.cuda.device_count() + if device_count == 1: + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("MASTER_ADDR", "0.0.0.0") # nosec B104 + os.environ.setdefault("MASTER_PORT", "29500") + # NOTE(djsaunde): The distributed state cannot be initialized prior to the # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior # to model load.