support val_set_size for splitting test split from train with DPO (#2572)
This commit is contained in:
@@ -3,15 +3,29 @@ DPO trainer for axolotl
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import random
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
import wandb
|
||||||
|
from accelerate import PartialState
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Trainer
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import (
|
||||||
|
BaseImageProcessor,
|
||||||
|
FeatureExtractionMixin,
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
ProcessorMixin,
|
||||||
|
Trainer,
|
||||||
|
)
|
||||||
|
from transformers.trainer_utils import EvalLoopOutput
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
|
||||||
|
from trl.trainer.utils import log_table_to_comet_experiment
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
@@ -81,6 +95,64 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
||||||
|
def _prepare_dataset(
|
||||||
|
self,
|
||||||
|
dataset: Union[Dataset, IterableDataset],
|
||||||
|
processing_class: Union[
|
||||||
|
PreTrainedTokenizerBase,
|
||||||
|
BaseImageProcessor,
|
||||||
|
FeatureExtractionMixin,
|
||||||
|
ProcessorMixin,
|
||||||
|
],
|
||||||
|
args: DPOConfig,
|
||||||
|
dataset_name: str,
|
||||||
|
) -> Union[Dataset, IterableDataset]:
|
||||||
|
# Build the kwargs for the `map` function
|
||||||
|
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
|
||||||
|
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
|
||||||
|
map_kwargs["num_proc"] = args.dataset_num_proc
|
||||||
|
|
||||||
|
with PartialState().main_process_first():
|
||||||
|
# Extract prompt if needed
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
|
||||||
|
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
|
||||||
|
|
||||||
|
# Apply the chat template if needed
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
|
||||||
|
dataset = dataset.map(
|
||||||
|
maybe_apply_chat_template,
|
||||||
|
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tokenize the dataset
|
||||||
|
if isinstance(
|
||||||
|
dataset, Dataset
|
||||||
|
): # `IterableDataset.map` does not support `desc`
|
||||||
|
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
||||||
|
|
||||||
|
dataset = dataset.map(
|
||||||
|
self.tokenize_row if not self.is_vision_model else self.process_row,
|
||||||
|
remove_columns=["chosen", "rejected"],
|
||||||
|
fn_kwargs={
|
||||||
|
"processing_class": processing_class,
|
||||||
|
"max_prompt_length": args.max_prompt_length,
|
||||||
|
"max_completion_length": args.max_completion_length,
|
||||||
|
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
|
||||||
|
"add_special_tokens": False,
|
||||||
|
},
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
features,
|
features,
|
||||||
@@ -124,3 +196,67 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
|
||||||
|
def evaluation_loop(
|
||||||
|
self,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
description: str,
|
||||||
|
prediction_loss_only: Optional[bool] = None,
|
||||||
|
ignore_keys: Optional[list[str]] = None,
|
||||||
|
metric_key_prefix: str = "eval",
|
||||||
|
) -> EvalLoopOutput:
|
||||||
|
"""
|
||||||
|
Overriding built-in evaluation loop to store metrics for each batch.
|
||||||
|
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
||||||
|
|
||||||
|
Works both with or without labels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Sample and save to game log if requested (for one batch to save time)
|
||||||
|
if self.generate_during_eval:
|
||||||
|
# Generate random indices within the range of the total number of samples
|
||||||
|
num_samples = len(dataloader.dataset)
|
||||||
|
random_indices = random.sample(
|
||||||
|
range(num_samples), k=self.args.eval_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
||||||
|
random_batch_dataset = dataloader.dataset.select(random_indices)
|
||||||
|
random_batch = self.data_collator(random_batch_dataset)
|
||||||
|
random_batch = self._prepare_inputs(random_batch)
|
||||||
|
|
||||||
|
policy_output_decoded, ref_output_decoded = (
|
||||||
|
self.generate_from_model_and_ref(self.model, random_batch)
|
||||||
|
)
|
||||||
|
|
||||||
|
table = pd.DataFrame(
|
||||||
|
columns=["Prompt", "Policy", "Ref Model"],
|
||||||
|
data=[
|
||||||
|
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
||||||
|
for prompt, pol, ref in zip(
|
||||||
|
random_batch_dataset["prompt"],
|
||||||
|
policy_output_decoded,
|
||||||
|
ref_output_decoded,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
|
||||||
|
wandb.log({"game_log": wandb.Table(data=table)})
|
||||||
|
|
||||||
|
if "comet_ml" in self.args.report_to:
|
||||||
|
log_table_to_comet_experiment(
|
||||||
|
name="game_log.csv",
|
||||||
|
table=table,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base evaluation
|
||||||
|
initial_output = super().evaluation_loop(
|
||||||
|
dataloader,
|
||||||
|
description,
|
||||||
|
prediction_loss_only,
|
||||||
|
ignore_keys,
|
||||||
|
metric_key_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
return initial_output
|
||||||
|
|||||||
@@ -204,7 +204,37 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
else:
|
else:
|
||||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||||
if not eval_dataset:
|
if not eval_dataset:
|
||||||
eval_dataset = None
|
if cfg.val_set_size:
|
||||||
|
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||||
|
to_hash_train = (
|
||||||
|
train_dataset._fingerprint # pylint: disable=protected-access
|
||||||
|
+ "|"
|
||||||
|
+ str(cfg.val_set_size)
|
||||||
|
+ "|"
|
||||||
|
+ "train"
|
||||||
|
+ "|"
|
||||||
|
+ str(cfg.seed or 42)
|
||||||
|
)
|
||||||
|
to_hash_test = (
|
||||||
|
train_dataset._fingerprint # pylint: disable=protected-access
|
||||||
|
+ "|"
|
||||||
|
+ str(cfg.val_set_size)
|
||||||
|
+ "|"
|
||||||
|
+ "test"
|
||||||
|
+ "|"
|
||||||
|
+ str(cfg.seed or 42)
|
||||||
|
)
|
||||||
|
train_fingerprint = md5(to_hash_train)
|
||||||
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
ds_w_test_split = train_dataset.train_test_split(
|
||||||
|
test_size=cfg.val_set_size,
|
||||||
|
seed=cfg.seed,
|
||||||
|
shuffle=False,
|
||||||
|
train_new_fingerprint=train_fingerprint,
|
||||||
|
test_new_fingerprint=test_fingerprint,
|
||||||
|
)
|
||||||
|
eval_dataset = ds_w_test_split["test"]
|
||||||
|
train_dataset = ds_w_test_split["train"]
|
||||||
|
|
||||||
if not train_is_preprocessed:
|
if not train_is_preprocessed:
|
||||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user