From 5e4a760ad84414af1c2ad3cf641afa2ca6530e7d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 3 Sep 2023 22:41:29 -0400
Subject: [PATCH] start to swap out for accelerate partial state

---
 src/axolotl/utils/callbacks.py   |  8 +++---
 src/axolotl/utils/data.py        |  9 +++---
 src/axolotl/utils/distributed.py | 49 +++++++++++---------------------
 3 files changed, 25 insertions(+), 41 deletions(-)

diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index ee5acfd55..4f633cd9e 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -11,6 +11,7 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.distributed as dist
+from accelerate.state import PartialState
 from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
@@ -24,11 +25,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.distributed import (
-    barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
     is_main_process,
-    zero_first,
 )
 
 if TYPE_CHECKING:
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
 
 LOG = logging.getLogger("axolotl.callbacks")
 IGNORE_INDEX = -100
+dist_state = PartialState()
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
             "subject": example["subject"],
         }
 
-    with zero_first(is_main_process()):
+    with dist_state.main_process_first():
         bench_dataset = bench_dataset.map(tokenize_evals)
         bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
 
@@ -258,7 +258,7 @@
         for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
             bench_names[s]["preds"].append(p)
             bench_names[s]["refs"].append(r)
-        barrier()
+        dist_state.wait_for_everyone()
         local_bench_names = bench_names
         gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
         # Gather results from all GPUs to GPU 0
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 20d0fcfb8..66029222c 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Tuple, Union
 
 import torch
+from accelerate.state import PartialState
 from datasets import (
     Dataset,
     DatasetDict,
@@ -42,7 +43,6 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
 )
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
 
 LOG = logging.getLogger("axolotl")
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
+state = PartialState()
 
 
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
-        with zero_first(is_main_process()):
+        with state.main_process_first():
             train_dataset, eval_dataset = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
@@ -69,7 +70,7 @@
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
 
-    with zero_first(is_main_process()):
+    with state.main_process_first():
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
@@ -507,7 +508,7 @@ def load_prepare_datasets(
             to_hash_test.encode(), usedforsecurity=False
         ).hexdigest()
 
-        with zero_first(is_main_process()):
+        with state.main_process_first():
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 38d0d1e05..9f8e9922b 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -1,29 +1,27 @@
 """
 utility helpers for distributed checks
 """
-import os
-from contextlib import contextmanager
-
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import DistributedType
+from accelerate.state import PartialState
+from accelerate.utils import wait_for_everyone
 
 accelerate = None  # pylint: disable=invalid-name
 
-
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
+state = PartialState()
 
 
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global accelerate  # pylint: disable=global-statement
-    if not accelerate:
-        accelerate = Accelerator()
-    return dist.is_available() and dist.is_initialized()
+    return state.distributed_type in (
+        DistributedType.MULTI_GPU,
+        DistributedType.MULTI_CPU,
+        DistributedType.DEEPSPEED,
+        DistributedType.FSDP,
+    )
 
 
 def barrier():
@@ -31,34 +29,19 @@ def barrier():
     """
     Acts as a barrier to wait for all processes. This ensures that all processes
     reach the barrier before proceeding further.
     """
-    if is_distributed():
-        dist.barrier()
+    wait_for_everyone()
 
 
-def is_main_process():
+def is_main_process() -> bool:
     """
     Check if the current process is the main process.
    If not in distributed mode, always return True.
     """
-    if not is_distributed():
-        return True
-    return dist.get_rank() == 0
+    return state.is_main_process
 
 
-def get_world_size():
-    return int(os.getenv("WORLD_SIZE", "1"))
-
-
-@contextmanager
-def zero_first(is_main):
-    """
-    runs the wrapped context so that rank 0 runs first before other ranks
-    """
-    if not is_main:  # other ranks wait first
-        barrier()
-    yield
-    if is_main:  # then rank 0 waits after it has run the context
-        barrier()
+def get_world_size() -> int:
+    return state.num_processes
 
 
 def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
@@ -76,7 +59,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
 
-    if not is_main_process():
+    if not state.is_main_process:
         dist.gather(value_tensor, dst=0)
     else:
         gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
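
Note on the swap: PartialState.main_process_first() is a drop-in for the removed
zero_first() context manager, so the datasets map/filter/train_test_split calls
above keep their run-once-then-read-cache ordering. A minimal standalone sketch
of that ordering, assuming an `accelerate launch --num_processes 2` run (the
print payload is illustrative):

    from accelerate.state import PartialState

    state = PartialState()  # shared singleton state; cheap to construct anywhere

    # Rank 0 enters the body first while the other ranks block; once rank 0
    # exits, the remaining ranks run the body and can read any cache it wrote.
    with state.main_process_first():
        print(f"rank {state.process_index} entering guarded block")

    state.wait_for_everyone()  # the same barrier the old barrier() helper provided

On a single process the context manager simply yields, so non-distributed runs
are unaffected.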
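
For the rewritten helpers in src/axolotl/utils/distributed.py, is_distributed()
now keys off state.distributed_type rather than torch.distributed, so DeepSpeed
and FSDP launches also report as distributed. A quick driver sketch, assuming
axolotl is importable and the script is started under torchrun or accelerate
launch (prints are illustrative):

    # exercise the PartialState-backed helpers
    from axolotl.utils.distributed import (
        barrier,
        get_world_size,
        is_distributed,
        is_main_process,
    )

    if is_main_process():  # True on rank 0, and always True single-process
        print(f"distributed={is_distributed()} world_size={get_world_size()}")
    barrier()  # now a thin wrapper over accelerate.utils.wait_for_everyone()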