Compare commits
2 Commits
fix/doc-ke
...
multi-gpu-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83d904a27d | ||
|
|
5e4a760ad8 |
@@ -11,6 +11,7 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from accelerate.state import PartialState
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -24,11 +25,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
|
||||||
gather_scalar_from_all_ranks,
|
gather_scalar_from_all_ranks,
|
||||||
get_world_size,
|
get_world_size,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -36,6 +35,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
dist_state = PartialState()
|
||||||
|
|
||||||
|
|
||||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||||
@@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
"subject": example["subject"],
|
"subject": example["subject"],
|
||||||
}
|
}
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with dist_state.main_process_first():
|
||||||
bench_dataset = bench_dataset.map(tokenize_evals)
|
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||||
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||||
|
|
||||||
@@ -258,7 +258,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
||||||
bench_names[s]["preds"].append(p)
|
bench_names[s]["preds"].append(p)
|
||||||
bench_names[s]["refs"].append(r)
|
bench_names[s]["refs"].append(r)
|
||||||
barrier()
|
dist_state.wait_for_everyone()
|
||||||
local_bench_names = bench_names
|
local_bench_names = bench_names
|
||||||
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
||||||
# Gather results from all GPUs to GPU 0
|
# Gather results from all GPUs to GPU 0
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate.state import PartialState
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -42,7 +43,6 @@ from axolotl.prompters import (
|
|||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
@@ -50,11 +50,12 @@ from axolotl.utils.trainer import (
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
state = PartialState()
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with state.main_process_first():
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
@@ -69,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with state.main_process_first():
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
@@ -507,7 +508,7 @@ def load_prepare_datasets(
|
|||||||
to_hash_test.encode(), usedforsecurity=False
|
to_hash_test.encode(), usedforsecurity=False
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with state.main_process_first():
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=cfg.val_set_size,
|
test_size=cfg.val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
|||||||
@@ -1,29 +1,27 @@
|
|||||||
"""
|
"""
|
||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import Accelerator
|
from accelerate import DistributedType
|
||||||
|
from accelerate.state import PartialState
|
||||||
|
from accelerate.utils import wait_for_everyone
|
||||||
|
|
||||||
accelerate = None # pylint: disable=invalid-name
|
accelerate = None # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
state = PartialState()
|
||||||
def load_accelerate():
|
|
||||||
global accelerate # pylint: disable=global-statement
|
|
||||||
accelerate = Accelerator()
|
|
||||||
|
|
||||||
|
|
||||||
def is_distributed():
|
def is_distributed():
|
||||||
"""
|
"""
|
||||||
Check if distributed training is initialized.
|
Check if distributed training is initialized.
|
||||||
"""
|
"""
|
||||||
global accelerate # pylint: disable=global-statement
|
return state.distributed_type in (
|
||||||
if not accelerate:
|
DistributedType.MULTI_GPU,
|
||||||
accelerate = Accelerator()
|
DistributedType.MULTI_CPU,
|
||||||
return dist.is_available() and dist.is_initialized()
|
DistributedType.DEEPSPEED,
|
||||||
|
DistributedType.FSDP,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def barrier():
|
def barrier():
|
||||||
@@ -31,34 +29,19 @@ def barrier():
|
|||||||
Acts as a barrier to wait for all processes. This ensures that all processes
|
Acts as a barrier to wait for all processes. This ensures that all processes
|
||||||
reach the barrier before proceeding further.
|
reach the barrier before proceeding further.
|
||||||
"""
|
"""
|
||||||
if is_distributed():
|
wait_for_everyone()
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
|
|
||||||
def is_main_process():
|
def is_main_process() -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the current process is the main process.
|
Check if the current process is the main process.
|
||||||
If not in distributed mode, always return True.
|
If not in distributed mode, always return True.
|
||||||
"""
|
"""
|
||||||
if not is_distributed():
|
return state.is_main_process
|
||||||
return True
|
|
||||||
return dist.get_rank() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def get_world_size():
|
def get_world_size() -> int:
|
||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return state.num_processes
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def zero_first(is_main):
|
|
||||||
"""
|
|
||||||
runs the wrapped context so that rank 0 runs first before other ranks
|
|
||||||
"""
|
|
||||||
if not is_main: # other ranks wait first
|
|
||||||
barrier()
|
|
||||||
yield
|
|
||||||
if is_main: # then rank 0 waits after it has run the context
|
|
||||||
barrier()
|
|
||||||
|
|
||||||
|
|
||||||
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
||||||
@@ -76,7 +59,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||||
|
|
||||||
if not is_main_process():
|
if not state.is_main_process:
|
||||||
dist.gather(value_tensor, dst=0)
|
dist.gather(value_tensor, dst=0)
|
||||||
else:
|
else:
|
||||||
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
|
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
|
||||||
|
|||||||
Reference in New Issue
Block a user