Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
83d904a27d fix the context manager call
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-03 22:49:28 -04:00
Wing Lian
5e4a760ad8 start to swap out for accelerate partial state 2023-09-03 22:41:29 -04:00
3 changed files with 25 additions and 41 deletions

View File

@@ -11,6 +11,7 @@ import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from accelerate.state import PartialState
from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm
@@ -24,11 +25,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING:
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
dist_state = PartialState()
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
"subject": example["subject"],
}
with zero_first(is_main_process()):
with dist_state.main_process_first():
bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
@@ -258,7 +258,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r)
barrier()
dist_state.wait_for_everyone()
local_bench_names = bench_names
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
# Gather results from all GPUs to GPU 0

View File

@@ -7,6 +7,7 @@ from pathlib import Path
from typing import Tuple, Union
import torch
from accelerate.state import PartialState
from datasets import (
Dataset,
DatasetDict,
@@ -42,7 +43,6 @@ from axolotl.prompters import (
SummarizeTLDRPrompter,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
state = PartialState()
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
with state.main_process_first():
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -69,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
with zero_first(is_main_process()):
with state.main_process_first():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
@@ -507,7 +508,7 @@ def load_prepare_datasets(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with zero_first(is_main_process()):
with state.main_process_first():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,

View File

@@ -1,29 +1,27 @@
"""
utility helpers for distributed checks
"""
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
from accelerate import Accelerator
from accelerate import DistributedType
from accelerate.state import PartialState
from accelerate.utils import wait_for_everyone
accelerate = None # pylint: disable=invalid-name
def load_accelerate():
    """
    Build a fresh ``Accelerator`` and publish it through the module-level
    ``accelerate`` name.
    """
    global accelerate  # pylint: disable=global-statement
    instance = Accelerator()
    accelerate = instance
state = PartialState()
# NOTE(review): this span looks like a rendered diff whose +/- markers and
# indentation were stripped, so two versions of the body appear back to back.
# Confirm against the repository before relying on either one.
def is_distributed():
"""
Check if distributed training is initialized.
"""
# Presumably the earlier body: lazily fills the module-level `accelerate`
# with an Accelerator, then reports whether torch.distributed is both
# available and initialized.
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
# Presumably the replacement body: true when the module-level `state`
# reports one of the four DistributedType values listed below.
return state.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_CPU,
DistributedType.DEEPSPEED,
DistributedType.FSDP,
)
def barrier():
@@ -31,34 +29,19 @@ def barrier():
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
if is_distributed():
dist.barrier()
wait_for_everyone()
# NOTE(review): this span looks like a rendered diff whose +/- markers and
# indentation were stripped, so two signatures and two bodies appear
# interleaved. Confirm against the repository before relying on either one.
def is_main_process():
def is_main_process() -> bool:
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
# Presumably the earlier body: returns True when is_distributed() is false,
# otherwise compares the torch.distributed rank with 0.
if not is_distributed():
return True
return dist.get_rank() == 0
# Presumably the replacement body: reads the flag from the module-level `state`.
return state.is_main_process
def get_world_size():
    """
    Read the ``WORLD_SIZE`` environment variable and return it as an int,
    falling back to 1 when the variable is not set.
    """
    raw_value = os.environ.get("WORLD_SIZE", "1")
    return int(raw_value)
@contextmanager
def zero_first(is_main):
    """
    runs the wrapped context so that rank 0 runs first before other ranks

    Callers that pass a falsy ``is_main`` call ``barrier()`` before entering
    the wrapped body; the caller that passes a truthy ``is_main`` runs the
    body first and calls ``barrier()`` afterwards.

    The trailing barrier sits in a ``finally`` clause so that it is still
    reached when the wrapped body raises on the main rank. Without it, an
    exception there would skip the barrier and leave the other ranks blocked
    at theirs. The exception itself still propagates to the caller.
    """
    if not is_main:  # other ranks wait first
        barrier()
    try:
        yield
    finally:
        if is_main:  # then rank 0 waits after it has run the context
            barrier()
def get_world_size() -> int:
    """
    Return the ``num_processes`` value held by the module-level ``state``.
    """
    process_count = state.num_processes
    return process_count
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
@@ -76,7 +59,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
if not is_main_process():
if not state.is_main_process:
dist.gather(value_tensor, dst=0)
else:
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]