diff --git a/scripts/finetune.py b/scripts/finetune.py
index bb96d9789..944876a33 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
 
-    if is_main_process():
-        # process on rank 0 first so it gets cached so other ranks load from cache
+    with zero_first(is_main_process()):
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
-    barrier()
-    if not is_main_process():
-        train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset
-        )
-    barrier()
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index fc30c4ce3..6ebfb9d17 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -41,7 +41,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
-from axolotl.utils.distributed import barrier, is_main_process
+from axolotl.utils.distributed import is_main_process, zero_first
 
 LOG = logging.getLogger("axolotl")
 
@@ -440,7 +440,7 @@ def load_prepare_datasets(
             to_hash_test.encode(), usedforsecurity=False
         ).hexdigest()
 
-        if is_main_process():
+        with zero_first(is_main_process()):
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
@@ -448,16 +448,6 @@ def load_prepare_datasets(
                 train_new_fingerprint=train_fingerprint,
                 test_new_fingerprint=test_fingerprint,
             )
-        barrier()
-        if not is_main_process():
-            dataset = dataset.train_test_split(
-                test_size=cfg.val_set_size,
-                shuffle=False,
-                seed=cfg.seed or 42,
-                train_new_fingerprint=train_fingerprint,
-                test_new_fingerprint=test_fingerprint,
-            )
-        barrier()
 
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 345b9640c..b3ea07c05 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -1,6 +1,8 @@
 """
 utility helpers for distributed checks
 """
+from contextlib import contextmanager
+
 import torch.distributed as dist
 from accelerate import Accelerator
 
@@ -39,3 +41,15 @@ def is_main_process():
     if not is_distributed():
         return True
     return dist.get_rank() == 0
+
+
+@contextmanager
+def zero_first(is_main):
+    """
+    runs the wrapped context so that rank 0 runs first before other ranks
+    """
+    if not is_main:  # other ranks wait first
+        barrier()
+    yield
+    if is_main:  # then rank 0 waits after it has run the context
+        barrier()
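
Note on the pattern: `zero_first` replaces the removed `if is_main_process(): ... barrier()` dance. Ranks other than 0 block on a barrier before entering the block, and rank 0 releases them only after it has finished, so its side effects (for example the Hugging Face `datasets` cache written by `process_datasets_for_packing` or `train_test_split`) are already on disk when the other ranks run the same code. Below is a minimal, self-contained sketch of that pattern; the `is_distributed`, `barrier`, and `is_main_process` helpers here are simplified stand-ins for the ones in `src/axolotl/utils/distributed.py`, not part of this patch, and the script assumes `torch.distributed` is either left uninitialized (single process) or already set up, e.g. via `torchrun`.

```python
from contextlib import contextmanager

import torch.distributed as dist


def is_distributed() -> bool:
    # true only when a torch.distributed process group has been initialized
    return dist.is_available() and dist.is_initialized()


def barrier() -> None:
    # no-op in single-process runs so the same code works without torchrun
    if is_distributed():
        dist.barrier()


def is_main_process() -> bool:
    # rank 0, or the only process when not running distributed
    return not is_distributed() or dist.get_rank() == 0


@contextmanager
def zero_first(is_main: bool):
    """Run the wrapped block on rank 0 before any other rank runs it."""
    if not is_main:  # non-zero ranks wait here until rank 0 is done
        barrier()
    yield
    if is_main:  # rank 0 releases the waiting ranks after its work is done
        barrier()


if __name__ == "__main__":
    with zero_first(is_main_process()):
        # expensive but cacheable work, e.g. tokenizing or splitting a dataset:
        # rank 0 populates the cache, the remaining ranks reuse it afterwards
        print("preparing datasets on this rank")
```

The barrier counts stay balanced: every non-zero rank calls `barrier()` exactly once before the block, and rank 0 calls it exactly once after, so each entry into the context performs a single rendezvous and cannot deadlock the way mismatched manual `barrier()` calls can.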