Fix for gather across multiple GPUs
This commit is contained in:
@@ -212,7 +212,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
bench_dataset = bench_dataset.map(tokenize_evals)
|
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||||
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-1] in abcd_idx)
|
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||||
|
|
||||||
class BenchEvalCallback(TrainerCallback):
|
class BenchEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
@@ -248,7 +248,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
preds.append(torch.argmax(logit_abcd).item())
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
refs += [
|
refs += [
|
||||||
abcd_idx.index(label) if labels in abcd_idx else -1
|
abcd_idx.index(label) if label in abcd_idx else -1
|
||||||
for label in labels.tolist()
|
for label in labels.tolist()
|
||||||
]
|
]
|
||||||
loss_bench += loss.item()
|
loss_bench += loss.item()
|
||||||
@@ -259,19 +259,24 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
bench_names[s]["preds"].append(p)
|
bench_names[s]["preds"].append(p)
|
||||||
bench_names[s]["refs"].append(r)
|
bench_names[s]["refs"].append(r)
|
||||||
barrier()
|
barrier()
|
||||||
bench_loss = sum(
|
|
||||||
gather_scalar_from_all_ranks(lambda: loss_bench, get_world_size())
|
|
||||||
) / sum(
|
|
||||||
gather_scalar_from_all_ranks(lambda: len(data_loader), get_world_size())
|
|
||||||
)
|
|
||||||
results = {"bench_loss": bench_loss}
|
|
||||||
|
|
||||||
local_bench_names = bench_names
|
local_bench_names = bench_names
|
||||||
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
||||||
# Gather results from all GPUs to GPU 0
|
# Gather results from all GPUs to GPU 0
|
||||||
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
|
||||||
|
|
||||||
if is_main_process():
|
loss_bench_ranks = gather_scalar_from_all_ranks(
|
||||||
|
lambda: loss_bench, get_world_size()
|
||||||
|
)
|
||||||
|
len_data_loader_ranks = gather_scalar_from_all_ranks(
|
||||||
|
lambda: len(data_loader), get_world_size()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_main_process():
|
||||||
|
dist.gather_object(local_bench_names, dst=0)
|
||||||
|
else:
|
||||||
|
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||||
|
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
||||||
|
results = {"bench_loss": bench_loss}
|
||||||
|
|
||||||
# Combine results from all GPUs
|
# Combine results from all GPUs
|
||||||
combined_bench_names: Dict[str, Dict[str, List]] = {}
|
combined_bench_names: Dict[str, Dict[str, List]] = {}
|
||||||
for bench_name in gathered_bench_names:
|
for bench_name in gathered_bench_names:
|
||||||
|
|||||||
@@ -76,15 +76,12 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||||
|
|
||||||
# Placeholder tensor for gathering results
|
if not is_main_process():
|
||||||
if is_main_process():
|
dist.gather(value_tensor, dst=0)
|
||||||
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
|
|
||||||
else:
|
else:
|
||||||
gathered_tensors = None
|
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
|
||||||
|
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
|
||||||
|
|
||||||
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
|
|
||||||
|
|
||||||
if is_main_process():
|
|
||||||
# Convert tensors back to their original type (int or float)
|
# Convert tensors back to their original type (int or float)
|
||||||
gathered_values = []
|
gathered_values = []
|
||||||
for tensor in gathered_tensors:
|
for tensor in gathered_tensors:
|
||||||
|
|||||||
Reference in New Issue
Block a user