don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing * add e2e test to validate preprocess cli with deepspeed * ignore duplicate code for cfg
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""CLI to run preprocessing of a dataset."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -95,6 +96,7 @@ def do_cli(
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
|
||||
@@ -546,7 +546,10 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
||||
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||
# to model load.
|
||||
if int(os.environ.get("WORLD_SIZE", "1")) == 1:
|
||||
if (
|
||||
int(os.environ.get("WORLD_SIZE", "1")) == 1
|
||||
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
|
||||
):
|
||||
os.environ["WORLD_SIZE"] = "1" # force it in case not set
|
||||
os.environ["LOCAL_RANK"] = "0" # force it in case not set
|
||||
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
|
||||
|
||||
Reference in New Issue
Block a user