use context manager to run things on rank0 before others (#397)

This commit is contained in:
Wing Lian
2023-08-15 00:10:47 -04:00
committed by GitHub
parent 1687be6a35
commit fc2d6be96d
3 changed files with 18 additions and 21 deletions

View File

@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
if is_main_process(): with zero_first(is_main_process()):
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if cfg.max_steps: if cfg.max_steps:
total_num_steps = min( total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps

View File

@@ -41,7 +41,7 @@ from axolotl.prompters import (
ShareGPTPrompter, ShareGPTPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
) )
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import is_main_process, zero_first
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -440,7 +440,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False to_hash_test.encode(), usedforsecurity=False
).hexdigest() ).hexdigest()
if is_main_process(): with zero_first(is_main_process()):
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
test_size=cfg.val_set_size, test_size=cfg.val_set_size,
shuffle=False, shuffle=False,
@@ -448,16 +448,6 @@ def load_prepare_datasets(
train_new_fingerprint=train_fingerprint, train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint, test_new_fingerprint=test_fingerprint,
) )
barrier()
if not is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]

View File

@@ -1,6 +1,8 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
from contextlib import contextmanager
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import Accelerator
@@ -39,3 +41,15 @@ def is_main_process():
if not is_distributed(): if not is_distributed():
return True return True
return dist.get_rank() == 0 return dist.get_rank() == 0
@contextmanager
def zero_first(is_main):
"""
runs the wrapped context so that rank 0 runs first before other ranks
"""
if not is_main: # other ranks wait first
barrier()
yield
if is_main: # then rank 0 waits after it has run the context
barrier()