Compare commits: model-load...benchmark- (23 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c3de28942c | |
| | 45848a9285 | |
| | d6cea18034 | |
| | 606846e0a5 | |
| | a6c9223114 | |
| | 8b16ecd448 | |
| | f5db88a10d | |
| | 99d844f215 | |
| | aefd4d74fa | |
| | 24b0e93235 | |
| | 2455254b92 | |
| | 918e040601 | |
| | ef062d8fcb | |
| | d4c8b66f3d | |
| | 64e9824d3e | |
| | 1134654c98 | |
| | 2fc756c289 | |
| | 943b84c490 | |
| | 6f166464d8 | |
| | e3b07402a7 | |
| | 8d3c8a3eab | |
| | c30120e684 | |
| | 9aed60fa54 | |
`requirements.txt`

```diff
@@ -4,6 +4,7 @@ transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
 addict
+evaluate
 fire
 PyYAML>=6.0
 datasets
```
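Note: `evaluate` is the only new dependency; it backs the accuracy metric used by the benchmark callback below. A minimal sketch of that API in isolation (values are illustrative):

```python
import evaluate

# Loads the Hugging Face "accuracy" metric, which compares integer labels.
accuracy = evaluate.load("accuracy")

# Example: three of four multiple-choice predictions match the references.
score = accuracy.compute(references=[0, 1, 2, 3], predictions=[0, 1, 2, 0])
print(score["accuracy"])  # 0.75
```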
`src/axolotl/utils/callbacks.py`

```diff
@@ -1,9 +1,19 @@
 """Callbacks for Trainer class"""
 
+from __future__ import annotations
+
 import logging
 import os
+from typing import TYPE_CHECKING, Dict, List
 
+import evaluate
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
+from tqdm import tqdm
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -13,8 +23,19 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.distributed import (
+    barrier,
+    gather_scalar_from_all_ranks,
+    get_world_size,
+    is_main_process,
+    zero_first,
+)
+
+if TYPE_CHECKING:
+    from axolotl.utils.trainer import AxolotlTrainingArguments
 
 LOG = logging.getLogger("axolotl.callbacks")
+IGNORE_INDEX = -100
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -96,3 +117,192 @@ class GPUStatsCallback(
         log_gpu_memory_usage(LOG, "while training", self.cfg.device)
         self.logged = True
         return control
+
+
+def bench_eval_callback_factory(trainer, tokenizer):
+    accuracy = evaluate.load("accuracy")
+    abcd_idx = [
+        tokenizer("A", add_special_tokens=False).input_ids[0],
+        tokenizer("B", add_special_tokens=False).input_ids[0],
+        tokenizer("C", add_special_tokens=False).input_ids[0],
+        tokenizer("D", add_special_tokens=False).input_ids[0],
+        tokenizer("E", add_special_tokens=False).input_ids[0],
+        tokenizer("F", add_special_tokens=False).input_ids[0],
+        tokenizer("G", add_special_tokens=False).input_ids[0],
+    ]
+    bench_split = "eval"
+
+    def transform_bench_subject(example):
+        # Split on ':' and trim whitespace
+        parts = example["subject"].split(":")
+        first_part = (
+            parts[0].strip().lower().replace("-", "_")
+        )  # Lowercase the first part
+        second_part = (
+            parts[1].strip().replace("-", "_") if len(parts) > 1 else "all"
+        )  # Replace hyphens with underscores
+
+        # Return the transformed values
+        return {"name": first_part, "subject": second_part}
+
+    if trainer.args.bench_dataset == "mmlu-zs":
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "zero_shot_mmlu_val.json",
+                "test": "zero_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns("subject")
+    # MMLU Five-shot (Eval/Test only)
+    elif trainer.args.bench_dataset in ["mmlu", "mmlu-fs"]:
+        bench_dataset = load_dataset(
+            "openaccess-ai-collective/mmlu-evals",
+            data_files={
+                "eval": "five_shot_mmlu_val.json",
+                "test": "five_shot_mmlu_test.json",
+            },
+        )
+        # bench_dataset = bench_dataset.remove_columns('subject')
+    elif "/" in trainer.args.bench_dataset:
+        bench_ds = trainer.args.bench_dataset
+        bench_ds_name = "/".join(bench_ds.split("/", 2)[:2])
+        bench_ds_data_file = "/".join(bench_ds.split("/", 2)[2:])
+        bench_dataset = load_dataset(
+            bench_ds_name,
+            data_files={
+                "eval": bench_ds_data_file,
+            },
+        )
+        bench_dataset["eval"] = bench_dataset["eval"].map(transform_bench_subject)
+    else:
+        raise ValueError(
+            f"unhandled value `{trainer.args.bench_dataset}` for bench_dataset training args"
+        )
+    bench_dataset = bench_dataset[trainer.args.bench_split]
+    if trainer.args.max_bench_samples is not None:
+        bench_dataset = bench_dataset.select(range(trainer.args.max_bench_samples))
+
+    def tokenize_evals(example):
+        source = f"{tokenizer.bos_token}{example['input']}"
+        target = f"{example['output']}{tokenizer.eos_token}"
+
+        tokenized_source = tokenizer(
+            source,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        tokenized_target = tokenizer(
+            target,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
+        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
+            "input_ids"
+        ]
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "subject": example["subject"],
+        }
+
+    with zero_first(is_main_process()):
+        bench_dataset = bench_dataset.map(tokenize_evals)
+        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
+
+    class BenchEvalCallback(TrainerCallback):
+        """
+        TrainerCallback that runs the MMLU evals
+        """
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,
+            state: TrainerState,  # pylint: disable=unused-argument
+            control: TrainerControl,  # pylint: disable=unused-argument
+            metrics: Dict[str, float],  # pylint: disable=unused-argument
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            data_loader = trainer.get_bench_dataloader(
+                bench_dataset.remove_columns(["input", "subject", "output", "name"])
+            )
+            trainer.model.eval()
+            preds, refs = [], []
+            loss_bench = 0
+            for batch in tqdm(data_loader, total=len(data_loader)):
+                (loss, logits, labels) = trainer.prediction_step(
+                    trainer.model,
+                    batch,
+                    prediction_loss_only=False,
+                )
+                # There are two tokens, the output, and eos token.
+                for i, logit in enumerate(logits):
+                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
+                        0
+                    ][0]
+                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
+                    preds.append(torch.argmax(logit_abcd).item())
+                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
+                refs += [
+                    abcd_idx.index(label) if label in abcd_idx else -1
+                    for label in labels.tolist()
+                ]
+                loss_bench += loss.item()
+            # Extract results by subject.
+            bench_name = bench_dataset["name"]
+            bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
+            for s, p, r in zip(bench_name, preds, refs):  # pylint: disable=invalid-name
+                bench_names[s]["preds"].append(p)
+                bench_names[s]["refs"].append(r)
+            barrier()
+            local_bench_names = bench_names
+            gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
+            # Gather results from all GPUs to GPU 0
+
+            loss_bench_ranks = gather_scalar_from_all_ranks(
+                lambda: loss_bench, get_world_size()
+            )
+            len_data_loader_ranks = gather_scalar_from_all_ranks(
+                lambda: len(data_loader), get_world_size()
+            )
+
+            if not is_main_process():
+                dist.gather_object(local_bench_names, dst=0)
+            else:
+                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
+                results = {"bench_loss": bench_loss}
+
+                # Combine results from all GPUs
+                combined_bench_names: Dict[str, Dict[str, List]] = {}
+                for bench_name in gathered_bench_names:
+                    for name, data in bench_name.items():
+                        if name not in combined_bench_names:
+                            combined_bench_names[name] = {"refs": [], "preds": []}
+                        combined_bench_names[name]["refs"].extend(data["refs"])
+                        combined_bench_names[name]["preds"].extend(data["preds"])
+
+                bench_scores = []
+                for (
+                    bench_name
+                ) in combined_bench_names:  # pylint: disable=consider-using-dict-items
+                    bench_score = accuracy.compute(
+                        references=combined_bench_names[bench_name]["refs"],
+                        predictions=combined_bench_names[bench_name]["preds"],
+                    )["accuracy"]
+                    if not pd.isna(bench_score):
+                        results[
+                            f"bench_{bench_split}_accuracy_{bench_name}"
+                        ] = bench_score
+                        bench_scores.append(bench_score)
+                    else:
+                        results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
+                        bench_scores.append(0.0)
+                results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
+                trainer.log(results)
+
+    return BenchEvalCallback
```
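The scoring in `on_evaluate` deserves a gloss: after `tokenize_evals`, each retained example carries exactly two unmasked label tokens (the answer letter, then EOS), so the prediction is read from the logits one position *before* the answer token and restricted to the candidate letter ids in `abcd_idx`; `labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]` then recovers the reference letter per example. A standalone sketch of that indexing, with made-up shapes and token ids:

```python
import torch

IGNORE_INDEX = -100
abcd_idx = [4, 7, 9]  # made-up token ids standing in for "A", "B", "C"

# One example: 6 positions, 10-token vocab; the answer token (7) sits at
# position 4, followed by EOS (2); every earlier position is masked out.
logits = torch.randn(6, 10)
labels = torch.tensor([-100, -100, -100, -100, 7, 2])

# Position of the first unmasked label, i.e. the answer token itself.
label_non_zero_id = (labels != IGNORE_INDEX).nonzero()[0][0]

# The logits that predict that position live one step earlier; restricting
# the argmax to the candidate letters forces the model to pick one of them.
logit_abcd = logits[label_non_zero_id - 1][abcd_idx]
pred = torch.argmax(logit_abcd).item()  # index into abcd_idx

# Reference side: keep the first of the two unmasked tokens per example.
ref_tokens = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
ref = abcd_idx.index(ref_tokens[0].item())  # here: 1 (token id 7 -> "B")
```

The earlier `filter(lambda x: x["labels"][-2] in abcd_idx)` is what guarantees the two-token layout this indexing relies on.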
`src/axolotl/utils/distributed.py`

```diff
@@ -1,8 +1,10 @@
 """
 utility helpers for distributed checks
 """
+import os
 from contextlib import contextmanager
 
+import torch
 import torch.distributed as dist
 from accelerate import Accelerator
 
@@ -43,6 +45,10 @@ def is_main_process():
     return dist.get_rank() == 0
 
 
+def get_world_size():
+    return int(os.getenv("WORLD_SIZE", "1"))
+
+
 @contextmanager
 def zero_first(is_main):
     """
@@ -53,3 +59,35 @@ def zero_first(is_main):
     yield
     if is_main:  # then rank 0 waits after it has run the context
         barrier()
+
+
+def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
+    """
+    Run a callable 'fn' on all ranks and gather the results on the specified rank.
+
+    Args:
+    - fn (callable): A function that computes the value. This should not have any side effects.
+    - rank (int, optional): The rank that gathers the values. Default is 0.
+    - world_size (int, optional): Total number of processes in the current distributed setup.
+
+    Returns:
+    - A list of computed values from all ranks if on the gathering rank, otherwise None.
+    """
+    value_scalar = fn()
+    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
+
+    if not is_main_process():
+        dist.gather(value_tensor, dst=0)
+    else:
+        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
+        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
+
+        # Convert tensors back to their original type (int or float)
+        gathered_values = []
+        for tensor in gathered_tensors:
+            if tensor == tensor.int():
+                gathered_values.append(int(tensor.item()))
+            else:
+                gathered_values.append(float(tensor.item()))
+        return gathered_values
+    return None
```
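A quick sketch of how the two new helpers compose, mirroring their use in the callback above. `preprocess` is a hypothetical stand-in, and a `torch.distributed` process group is assumed to be initialized:

```python
from axolotl.utils.distributed import (
    gather_scalar_from_all_ranks,
    get_world_size,
    is_main_process,
    zero_first,
)

# Rank 0 runs the block first while the others wait at the barrier, then
# they run it too -- typically hitting the dataset cache rank 0 just wrote.
with zero_first(is_main_process()):
    bench_dataset = preprocess()  # hypothetical expensive map/filter step

# Every rank computes a local scalar; rank 0 gets the full list back,
# all other ranks get None.
local_loss = 0.42  # stand-in for a per-rank value
losses = gather_scalar_from_all_ranks(lambda: local_loss, get_world_size())
if is_main_process():
    print(sum(losses) / len(losses))
```

Note the round-trip in `gather_scalar_from_all_ranks`: values travel as float tensors and are converted back to int when the tensor is integral, so small integer counts like `len(data_loader)` survive the trip intact.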
`src/axolotl/utils/trainer.py`

```diff
@@ -12,9 +12,15 @@ from typing import Optional, Union
 
 import numpy as np
 import torch.cuda
+import transformers
 from datasets import Dataset, set_caching_enabled
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import (
+    DataLoader,
+    DistributedSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import SequentialDistributedSampler
 
@@ -23,6 +29,7 @@ from axolotl.utils.callbacks import (
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
+    bench_eval_callback_factory,
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -127,6 +134,27 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
+    )
+    bench_dataset: Optional[str] = field(
+        default="pharaouk/dharma-1/dharma_1_mini.json",
+        metadata={
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+        },
+    )
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+    )
+    max_bench_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+        },
+    )
+    bench_source_max_len: int = field(
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -136,6 +164,10 @@ class AxolotlTrainer(Trainer):
 
     args = None  # type: AxolotlTrainingArguments
 
+    def __init__(self, *args, bench_data_collator=None, **kwargs):
+        self.bench_data_collator = bench_data_collator
+        super().__init__(*args, **kwargs)
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -226,6 +258,31 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_eval_dataloader(eval_dataset)
 
+    def _get_bench_sampler(
+        self, bench_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size <= 1:
+            return SequentialSampler(bench_dataset)
+        return None
+
+    def get_bench_dataloader(
+        self,
+        bench_dataset: Dataset,
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        dataloader_params = {
+            "batch_size": self.args.eval_batch_size,
+            "collate_fn": self.bench_data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+        }
+
+        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        return DataLoader(bench_dataset, **dataloader_params)
+        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
+
     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
         # if self.args.sample_packing:
@@ -517,6 +574,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
             "steps" if cfg.save_steps else "epoch"
         )
 
+    if cfg.do_bench_eval:
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
+        if cfg.bench_dataset:
+            training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
@@ -627,8 +689,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
             return_tensors="pt",
             **data_collator_kwargs,
         ),
+        bench_data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
+        ),
         callbacks=callbacks,
         **trainer_kwargs,
     )
+
+    if cfg.do_bench_eval:
+        trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
+
     return trainer
```
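Two usage notes on the new arguments. The default `bench_dataset` value illustrates the third form `bench_eval_callback_factory` accepts, an HF dataset repo id plus a data file; a worked example of how its `"/"` branch splits that default:

```python
# Worked example of the "/" parsing in bench_eval_callback_factory.
bench_ds = "pharaouk/dharma-1/dharma_1_mini.json"

parts = bench_ds.split("/", 2)  # ['pharaouk', 'dharma-1', 'dharma_1_mini.json']
bench_ds_name = "/".join(parts[:2])       # 'pharaouk/dharma-1'  (dataset repo id)
bench_ds_data_file = "/".join(parts[2:])  # 'dharma_1_mini.json' (file in that repo)

# load_dataset(bench_ds_name, data_files={"eval": bench_ds_data_file})
```

And since `setup_trainer` only forwards `do_bench_eval` and `bench_dataset` from the user config, enabling the benchmark should come down to setting those two keys (e.g. `do_bench_eval: true` and `bench_dataset: mmlu-zs` in the axolotl YAML config, assuming the usual one-key-per-option mapping).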