use context manager to run things on rank0 before others (#397)

This commit is contained in:
Wing Lian
2023-08-15 00:10:47 -04:00
committed by GitHub
parent 1687be6a35
commit fc2d6be96d
3 changed files with 18 additions and 21 deletions

View File

@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import (
@@ -198,17 +198,10 @@ def train(
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps