Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
83d904a27d fix the context manager call
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-03 22:49:28 -04:00
Wing Lian
5e4a760ad8 start to swap out for accelerate partial state 2023-09-03 22:41:29 -04:00
3 changed files with 25 additions and 41 deletions

View File

@@ -11,6 +11,7 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.distributed as dist
+from accelerate.state import PartialState
 from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
@@ -24,11 +25,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.distributed import (
-    barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
     is_main_process,
-    zero_first,
 )
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
 LOG = logging.getLogger("axolotl.callbacks")
 IGNORE_INDEX = -100
+dist_state = PartialState()
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
             "subject": example["subject"],
         }
-    with zero_first(is_main_process()):
+    with dist_state.main_process_first():
         bench_dataset = bench_dataset.map(tokenize_evals)
         bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
@@ -258,7 +258,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
         for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
             bench_names[s]["preds"].append(p)
             bench_names[s]["refs"].append(r)
-        barrier()
+        dist_state.wait_for_everyone()
         local_bench_names = bench_names
         gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
         # Gather results from all GPUs to GPU 0

View File

@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Tuple, Union
 import torch
+from accelerate.state import PartialState
 from datasets import (
     Dataset,
     DatasetDict,
@@ -42,7 +43,6 @@ from axolotl.prompters import (
     SummarizeTLDRPrompter,
 )
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
 LOG = logging.getLogger("axolotl")
 DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
+state = PartialState()
 def prepare_dataset(cfg, tokenizer):
     if not cfg.pretraining_dataset:
-        with zero_first(is_main_process()):
+        with state.main_process_first():
             train_dataset, eval_dataset = load_prepare_datasets(
                 tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
             )
@@ -69,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
-    with zero_first(is_main_process()):
+    with state.main_process_first():
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset
         )
@@ -507,7 +508,7 @@ def load_prepare_datasets(
             to_hash_test.encode(), usedforsecurity=False
         ).hexdigest()
-        with zero_first(is_main_process()):
+        with state.main_process_first():
             dataset = dataset.train_test_split(
                 test_size=cfg.val_set_size,
                 shuffle=False,

View File

@@ -1,29 +1,27 @@
 """
 utility helpers for distributed checks
 """
-import os
-from contextlib import contextmanager
 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import DistributedType
+from accelerate.state import PartialState
+from accelerate.utils import wait_for_everyone
 accelerate = None  # pylint: disable=invalid-name
+state = PartialState()
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global accelerate  # pylint: disable=global-statement
-    if not accelerate:
-        accelerate = Accelerator()
-    return dist.is_available() and dist.is_initialized()
+    return state.distributed_type in (
+        DistributedType.MULTI_GPU,
+        DistributedType.MULTI_CPU,
+        DistributedType.DEEPSPEED,
+        DistributedType.FSDP,
+    )
def barrier(): def barrier():
@@ -31,34 +29,19 @@ def barrier():
     Acts as a barrier to wait for all processes. This ensures that all processes
     reach the barrier before proceeding further.
     """
-    if is_distributed():
-        dist.barrier()
+    wait_for_everyone()
-def is_main_process():
+def is_main_process() -> bool:
     """
     Check if the current process is the main process.
     If not in distributed mode, always return True.
     """
-    if not is_distributed():
-        return True
-    return dist.get_rank() == 0
+    return state.is_main_process
-def get_world_size():
-    return int(os.getenv("WORLD_SIZE", "1"))
+def get_world_size() -> int:
+    return state.num_processes
-@contextmanager
-def zero_first(is_main):
-    """
-    runs the wrapped context so that rank 0 runs first before other ranks
-    """
-    if not is_main:  # other ranks wait first
-        barrier()
-    yield
-    if is_main:  # then rank 0 waits after it has run the context
-        barrier()
 def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
@@ -76,7 +59,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     value_scalar = fn()
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
-    if not is_main_process():
+    if not state.is_main_process:
         dist.gather(value_tensor, dst=0)
     else:
         gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]