From 19cd83d408ba0d46f2cf6e285488001eeaf4d1c1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 14 Jan 2025 22:07:55 -0500 Subject: [PATCH] rename references to dpo dataset prep to pref data (#2258) --- src/axolotl/common/datasets.py | 10 +++++----- src/axolotl/utils/data/__init__.py | 2 +- src/axolotl/utils/data/rl.py | 2 +- tests/test_datasets.py | 6 +++--- tests/test_exact_deduplication.py | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index d07add29b..c693c26d8 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -11,7 +11,7 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels @@ -103,9 +103,9 @@ def load_preference_datasets( cli_args: Union[PreprocessCliArgs, TrainerCliArgs], ) -> TrainDatasetMeta: """ - Loads one or more training or evaluation datasets for DPO training, calling - `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug - information. + Loads one or more training or evaluation datasets for RL training using paired + preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`. + Optionally, logs out debug information. Args: cfg: Dictionary mapping `axolotl` config keys to values. @@ -115,7 +115,7 @@ def load_preference_datasets( Dataclass with fields for training and evaluation datasets and the computed `total_num_steps`. """ - train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) + train_dataset, eval_dataset = load_prepare_preference_datasets(cfg) total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index 140d02106..7f90bf3cb 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401 encode_pretraining, wrap_pretraining_dataset, ) -from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 +from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401 get_dataset_wrapper, load_prepare_datasets, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index edb72f186..9f5c726ab 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -115,7 +115,7 @@ def drop_long_rl_seq( raise ValueError("Unknown RL type") -def load_prepare_dpo_datasets(cfg): +def load_prepare_preference_datasets(cfg): def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] for i, ds_cfg in enumerate(dataset_cfgs): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b1ecfd6d5..49554d370 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from axolotl.utils.data import load_tokenized_prepared_datasets -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault @@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase): } ) - train_dataset, _ = load_prepare_dpo_datasets(cfg) + train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 assert "conversation" in train_dataset.features @@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase): } ) - train_dataset, _ = load_prepare_dpo_datasets(cfg) + train_dataset, _ = load_prepare_preference_datasets(cfg) assert len(train_dataset) == 1800 assert "conversation" in train_dataset.features diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 2ac6415be..bc0734ed3 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -12,7 +12,7 @@ from datasets import Dataset from transformers import AutoTokenizer from axolotl.utils.data import prepare_dataset -from axolotl.utils.data.rl import load_prepare_dpo_datasets +from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer @@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase): """Verify that loading with deduplication removes duplicates.""" # Load the dataset using the deduplication setting - train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + train_dataset, _ = load_prepare_preference_datasets(self.cfg) # Verify that the dataset has been deduplicated assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" @@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase): """Verify that loading without deduplication retains duplicates.""" self.cfg.dataset_exact_deduplication = False # Load the dataset without deduplication - train_dataset, _ = load_prepare_dpo_datasets(self.cfg) + train_dataset, _ = load_prepare_preference_datasets(self.cfg) # Verify that the dataset retains duplicates assert (