rename references to dpo dataset prep to pref data (#2258)
This commit is contained in:
@@ -11,7 +11,7 @@ from datasets import Dataset
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
@@ -103,9 +103,9 @@ def load_preference_datasets(
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
) -> TrainDatasetMeta:
|
||||
"""
|
||||
Loads one or more training or evaluation datasets for DPO training, calling
|
||||
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
|
||||
information.
|
||||
Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||
Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -115,7 +115,7 @@ def load_preference_datasets(
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
"""
|
||||
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||
encode_pretraining,
|
||||
wrap_pretraining_dataset,
|
||||
)
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
|
||||
from axolotl.utils.data.sft import ( # noqa: F401
|
||||
get_dataset_wrapper,
|
||||
load_prepare_datasets,
|
||||
|
||||
@@ -115,7 +115,7 @@ def drop_long_rl_seq(
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def load_prepare_dpo_datasets(cfg):
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
for i, ds_cfg in enumerate(dataset_cfgs):
|
||||
|
||||
@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_dpo_datasets(cfg)
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
|
||||
@@ -12,7 +12,7 @@ from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# Load the dataset using the deduplication setting
|
||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
|
||||
"""Verify that loading without deduplication retains duplicates."""
|
||||
self.cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
||||
train_dataset, _ = load_prepare_preference_datasets(self.cfg)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user