use context manager to run things on rank0 before others (#397)
This commit is contained in:
@@ -21,7 +21,7 @@ from axolotl.logging_config import configure_logging
|
|||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
@@ -198,17 +198,10 @@ def train(
|
|||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
if is_main_process():
|
with zero_first(is_main_process()):
|
||||||
# process on rank 0 first so it gets cached so other ranks load from cache
|
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
barrier()
|
|
||||||
if not is_main_process():
|
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
|
||||||
cfg, train_dataset, eval_dataset
|
|
||||||
)
|
|
||||||
barrier()
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from axolotl.prompters import (
|
|||||||
ShareGPTPrompter,
|
ShareGPTPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.distributed import barrier, is_main_process
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -440,7 +440,7 @@ def load_prepare_datasets(
|
|||||||
to_hash_test.encode(), usedforsecurity=False
|
to_hash_test.encode(), usedforsecurity=False
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if is_main_process():
|
with zero_first(is_main_process()):
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
@@ -448,16 +448,6 @@ def load_prepare_datasets(
|
|||||||
train_new_fingerprint=train_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
test_new_fingerprint=test_fingerprint,
|
test_new_fingerprint=test_fingerprint,
|
||||||
)
|
)
|
||||||
barrier()
|
|
||||||
if not is_main_process():
|
|
||||||
dataset = dataset.train_test_split(
|
|
||||||
test_size=cfg.val_set_size,
|
|
||||||
shuffle=False,
|
|
||||||
seed=cfg.seed or 42,
|
|
||||||
train_new_fingerprint=train_fingerprint,
|
|
||||||
test_new_fingerprint=test_fingerprint,
|
|
||||||
)
|
|
||||||
barrier()
|
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
@@ -39,3 +41,15 @@ def is_main_process():
|
|||||||
if not is_distributed():
|
if not is_distributed():
|
||||||
return True
|
return True
|
||||||
return dist.get_rank() == 0
|
return dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def zero_first(is_main):
|
||||||
|
"""
|
||||||
|
runs the wrapped context so that rank 0 runs first before other ranks
|
||||||
|
"""
|
||||||
|
if not is_main: # other ranks wait first
|
||||||
|
barrier()
|
||||||
|
yield
|
||||||
|
if is_main: # then rank 0 waits after it has run the context
|
||||||
|
barrier()
|
||||||
|
|||||||
Reference in New Issue
Block a user