diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py
index 89c77dca4..3520aff10 100644
--- a/src/axolotl/core/trainers/dpo/trainer.py
+++ b/src/axolotl/core/trainers/dpo/trainer.py
@@ -3,15 +3,29 @@ DPO trainer for axolotl
 """
 
 import gc
+import random
 from functools import wraps
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
 
+import pandas as pd
 import torch
+import wandb
+from accelerate import PartialState
+from datasets import Dataset, IterableDataset
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
-from transformers import Trainer
+from torch.utils.data import DataLoader
+from transformers import (
+    BaseImageProcessor,
+    FeatureExtractionMixin,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+    Trainer,
+)
+from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
+from trl.trainer.utils import log_table_to_comet_experiment
 
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
@@ -81,6 +95,64 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
 
         return super().push_to_hub(*args, **kwargs)
 
+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: Union[
+            PreTrainedTokenizerBase,
+            BaseImageProcessor,
+            FeatureExtractionMixin,
+            ProcessorMixin,
+        ],
+        args: DPOConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Build the kwargs for the `map` function
+        map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            # Extract prompt if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
+            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
+
+            # Apply the chat template if needed
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
+            dataset = dataset.map(
+                maybe_apply_chat_template,
+                fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
+                **map_kwargs,
+            )
+
+            # Tokenize the dataset
+            if isinstance(
+                dataset, Dataset
+            ):  # `IterableDataset.map` does not support `desc`
+                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+            dataset = dataset.map(
+                self.tokenize_row if not self.is_vision_model else self.process_row,
+                remove_columns=["chosen", "rejected"],
+                fn_kwargs={
+                    "processing_class": processing_class,
+                    "max_prompt_length": args.max_prompt_length,
+                    "max_completion_length": args.max_completion_length,
+                    # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
+                    "add_special_tokens": False,
+                },
+                **map_kwargs,
+            )
+
+        return dataset
+
     @staticmethod
     def tokenize_row(
         features,
@@ -124,3 +196,67 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
         gc.collect()
         torch.cuda.empty_cache()
         return loss
+
+    # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Overriding built-in evaluation loop to store metrics for each batch.
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+
+        # Sample and save to game log if requested (for one batch to save time)
+        if self.generate_during_eval:
+            # Generate random indices within the range of the total number of samples
+            num_samples = len(dataloader.dataset)
+            random_indices = random.sample(
+                range(num_samples), k=self.args.eval_batch_size
+            )
+
+            # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
+            random_batch_dataset = dataloader.dataset.select(random_indices)
+            random_batch = self.data_collator(random_batch_dataset)
+            random_batch = self._prepare_inputs(random_batch)
+
+            policy_output_decoded, ref_output_decoded = (
+                self.generate_from_model_and_ref(self.model, random_batch)
+            )
+
+            table = pd.DataFrame(
+                columns=["Prompt", "Policy", "Ref Model"],
+                data=[
+                    [prompt, pol[len(prompt) :], ref[len(prompt) :]]
+                    for prompt, pol, ref in zip(
+                        random_batch_dataset["prompt"],
+                        policy_output_decoded,
+                        ref_output_decoded,
+                    )
+                ],
+            )
+            if "wandb" in self.args.report_to and self.accelerator.is_main_process:
+                wandb.log({"game_log": wandb.Table(data=table)})
+
+            if "comet_ml" in self.args.report_to:
+                log_table_to_comet_experiment(
+                    name="game_log.csv",
+                    table=table,
+                )
+
+        # Base evaluation
+        initial_output = super().evaluation_loop(
+            dataloader,
+            description,
+            prediction_loss_only,
+            ignore_keys,
+            metric_key_prefix,
+        )
+
+        return initial_output
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 4c7b71292..135de61a3 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -204,7 +204,37 @@ def load_prepare_preference_datasets(cfg):
     else:
         eval_dataset = load_split(cfg.test_datasets, cfg)
 
     if not eval_dataset:
-        eval_dataset = None
+        if cfg.val_set_size:
+            # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
+            to_hash_train = (
+                train_dataset._fingerprint  # pylint: disable=protected-access
+                + "|"
+                + str(cfg.val_set_size)
+                + "|"
+                + "train"
+                + "|"
+                + str(cfg.seed or 42)
+            )
+            to_hash_test = (
+                train_dataset._fingerprint  # pylint: disable=protected-access
+                + "|"
+                + str(cfg.val_set_size)
+                + "|"
+                + "test"
+                + "|"
+                + str(cfg.seed or 42)
+            )
+            train_fingerprint = md5(to_hash_train)
+            test_fingerprint = md5(to_hash_test)
+            ds_w_test_split = train_dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                seed=cfg.seed,
+                shuffle=False,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+            eval_dataset = ds_w_test_split["test"]
+            train_dataset = ds_w_test_split["train"]
 
     if not train_is_preprocessed:
         _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)