use context manager to run things on rank0 before others (#397)

Author: Wing Lian
Date: 2023-08-15 00:10:47 -04:00
Committed by: GitHub
Commit: fc2d6be96d
Parent: 1687be6a35

3 changed files with 18 additions and 21 deletions

scripts/finetune.py

@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
     train_dataset = train_dataset.with_format("torch")
     eval_dataset = None
-
-    if is_main_process():
+    # process on rank 0 first so it gets cached so other ranks load from cache
+    with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
-    barrier()
-    if not is_main_process():
-        train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset
-        )
-    barrier()
-
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
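Read as a pattern, this change collapses the duplicated guard-and-barrier dance into a single guarded call site. A minimal sketch of the before and after, assuming axolotl is importable; prepare() is a hypothetical stand-in for process_datasets_for_packing and is not part of the commit:

from axolotl.utils.distributed import barrier, is_main_process, zero_first


def prepare():
    # stand-in for an expensive step whose output lands in a shared cache
    return "datasets ready"


# before this commit: two call sites and two explicit barriers
if is_main_process():
    result = prepare()  # rank 0 computes and warms the cache
barrier()               # everyone waits for rank 0 to finish
if not is_main_process():
    result = prepare()  # remaining ranks reload from the warm cache
barrier()

# after this commit: one call site; the ordering lives in the context manager
with zero_first(is_main_process()):
    result = prepare()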

src/axolotl/utils/data.py

@@ -41,7 +41,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first
 
 LOG = logging.getLogger("axolotl")
@@ -440,7 +440,7 @@ def load_prepare_datasets(
             to_hash_test.encode(), usedforsecurity=False
         ).hexdigest()
 
-        if is_main_process():
+        with zero_first(is_main_process()):
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
@@ -448,16 +448,6 @@ def load_prepare_datasets(
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
             )
-        barrier()
-        if not is_main_process():
-            dataset = dataset.train_test_split(
-                test_size=cfg.val_set_size,
-                shuffle=False,
-                seed=cfg.seed or 42,
-                train_new_fingerprint=train_fingerprint,
-                test_new_fingerprint=test_fingerprint,
-            )
-        barrier()
 
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
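The split can run under zero_first only because the cache lookup is deterministic: train_new_fingerprint and test_new_fingerprint are md5 digests computed identically on every rank, so once rank 0 has materialized the split, the other ranks resolve the same cache entry and load it rather than recomputing, which is the "other ranks load from cache" behavior the new inline comment describes. A small illustrative sketch of that fingerprint pinning, with toy data and made-up digest strings:

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})

split = ds.train_test_split(
    test_size=0.1,
    shuffle=False,
    seed=42,
    train_new_fingerprint="0123456789abcdef",  # pinned, identical on all ranks
    test_new_fingerprint="fedcba9876543210",
)

# the resulting datasets carry the pinned fingerprints, which key the cache
print(split["train"]._fingerprint, split["test"]._fingerprint)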

src/axolotl/utils/distributed.py

@@ -1,6 +1,8 @@
 """
 utility helpers for distributed checks
 """
+from contextlib import contextmanager
+
 import torch.distributed as dist
 from accelerate import Accelerator
@@ -39,3 +41,15 @@ def is_main_process():
     if not is_distributed():
         return True
     return dist.get_rank() == 0
+
+
+@contextmanager
+def zero_first(is_main):
+    """
+    runs the wrapped context so that rank 0 runs first before other ranks
+    """
+    if not is_main:  # other ranks wait first
+        barrier()
+    yield
+    if is_main:  # then rank 0 waits after it has run the context
+        barrier()
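The trick is that every rank calls barrier() exactly once: the non-zero ranks block at the top, and rank 0 only joins the collective after its body has run, so the barrier completes, and the other ranks proceed, precisely when rank 0 is done. A standalone sketch of that ordering, assuming the gloo backend and a free port 29500, with torch.multiprocessing standing in for a real multi-node launcher; zero_first and barrier are copied inline so it runs without axolotl installed:

import os
from contextlib import contextmanager

import torch.distributed as dist
import torch.multiprocessing as mp


def barrier():
    dist.barrier()


@contextmanager
def zero_first(is_main):
    if not is_main:  # other ranks wait first
        barrier()
    yield
    if is_main:  # then rank 0 waits after it has run the context
        barrier()


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    with zero_first(dist.get_rank() == 0):
        # rank 0 always reaches this line before any other rank does
        print(f"rank {rank} running the guarded block")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

One behavioral difference from the old pattern is worth noting: there is no trailing barrier, so rank 0 exits the block without waiting for the slower ranks to finish their pass.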