don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing * add e2e test to validate preprocess cli with deepspeed * ignore duplicate code for cfg
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""CLI to run preprocessing of a dataset."""
|
"""CLI to run preprocessing of a dataset."""
|
||||||
|
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@@ -95,6 +96,7 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parsed_cfg.is_preprocess = True
|
parsed_cfg.is_preprocess = True
|
||||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||||
|
|||||||
@@ -546,7 +546,10 @@ def setup_deepspeed_env(cfg, stage=None):
|
|||||||
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
||||||
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||||
# to model load.
|
# to model load.
|
||||||
if int(os.environ.get("WORLD_SIZE", "1")) == 1:
|
if (
|
||||||
|
int(os.environ.get("WORLD_SIZE", "1")) == 1
|
||||||
|
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
|
||||||
|
):
|
||||||
os.environ["WORLD_SIZE"] = "1" # force it in case not set
|
os.environ["WORLD_SIZE"] = "1" # force it in case not set
|
||||||
os.environ["LOCAL_RANK"] = "0" # force it in case not set
|
os.environ["LOCAL_RANK"] = "0" # force it in case not set
|
||||||
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
|
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
|
||||||
|
|||||||
58
tests/e2e/test_preprocess.py
Normal file
58
tests/e2e/test_preprocess.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""E2E Test the preprocess cli"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreprocess:
|
||||||
|
"""test cases for preprocess"""
|
||||||
|
|
||||||
|
def test_w_deepspeed(self, temp_dir):
|
||||||
|
"""make sure preproces doesn't choke when using deepspeed in the config"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"preprocess",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (Path(temp_dir) / "last_run_prepared").exists()
|
||||||
Reference in New Issue
Block a user