don't init distributed for deepspeed if preprocessing (#2920)

* don't init distributed for deepspeed if preprocessing

* add e2e test to validate preprocess cli with deepspeed

* ignore duplicate code for cfg
This commit is contained in:
Wing Lian
2025-07-14 14:19:19 -04:00
committed by GitHub
parent 37edbe4999
commit ca4d4ef793
3 changed files with 64 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
"""CLI to run preprocessing of a dataset."""
import os
import warnings
from pathlib import Path
from typing import Union
@@ -95,6 +96,7 @@ def do_cli(
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)

View File

@@ -546,7 +546,10 @@ def setup_deepspeed_env(cfg, stage=None):
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if int(os.environ.get("WORLD_SIZE", "1")) == 1:
if (
int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
):
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")

View File

@@ -0,0 +1,58 @@
"""E2E Test the preprocess cli"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
class TestPreprocess:
"""test cases for preprocess"""
def test_w_deepspeed(self, temp_dir):
"""make sure preproces doesn't choke when using deepspeed in the config"""
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"dataset_prepared_path": temp_dir + "/last_run_prepared",
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"preprocess",
str(Path(temp_dir) / "config.yaml"),
]
)
assert (Path(temp_dir) / "last_run_prepared").exists()