use context manager to run things on rank0 before others (#397)
This commit is contained in:
@@ -41,7 +41,7 @@ from axolotl.prompters import (
|
||||
ShareGPTPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
)
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -440,7 +440,7 @@ def load_prepare_datasets(
|
||||
to_hash_test.encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
if is_main_process():
|
||||
with zero_first(is_main_process()):
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
@@ -448,16 +448,6 @@ def load_prepare_datasets(
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
if not is_main_process():
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
barrier()
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
utility helpers for distributed checks
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch.distributed as dist
|
||||
from accelerate import Accelerator
|
||||
|
||||
@@ -39,3 +41,15 @@ def is_main_process():
|
||||
if not is_distributed():
|
||||
return True
|
||||
return dist.get_rank() == 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_first(is_main):
|
||||
"""
|
||||
runs the wrapped context so that rank 0 runs first before other ranks
|
||||
"""
|
||||
if not is_main: # other ranks wait first
|
||||
barrier()
|
||||
yield
|
||||
if is_main: # then rank 0 waits after it has run the context
|
||||
barrier()
|
||||
|
||||
Reference in New Issue
Block a user