Compare commits


16 Commits

Author | SHA1 | Message | Date
Wing Lian | 5e8c492e3c | trainer refactor testing for hf#35567 | 2025-01-21 11:27:10 -05:00
Wing Lian | 9a683536c8 | upgrade accelerate also | 2025-01-21 10:15:16 -05:00
Wing Lian | faa61a9c3e | use official hf release for 4.48.1 | 2025-01-21 10:15:01 -05:00
Wing Lian | 59cb36564d | skip check for latest transformers | 2025-01-21 10:15:01 -05:00
Wing Lian | 50d4d727a0 | use wip branch for expected 4.48.1 | 2025-01-21 10:15:00 -05:00
Wing Lian | 0714a49227 | move relora test so it runs in a single test thread | 2025-01-21 10:15:00 -05:00
Wing Lian | b6daffb788 | fix import from mv | 2025-01-21 10:15:00 -05:00
Wing Lian | d487e377fa | move relora to the patched tests suite | 2025-01-21 10:15:00 -05:00
Wing Lian | 4cc89f73f0 | fix patch | 2025-01-21 10:15:00 -05:00
Wing Lian | 5b5ba49c46 | latest fixes needed for GA in latest transformers | 2025-01-21 10:15:00 -05:00
Wing Lian | 49b5501fc2 | unsloth incompatible with latest transformers | 2025-01-21 10:15:00 -05:00
Wing Lian | 23389b38b7 | bump to latest transformers release | 2025-01-21 10:15:00 -05:00
Wing Lian | af727eedf7 | option to not concatenate during pretraining (#2263) | 2025-01-20 14:07:34 -05:00
    * option to not concatenate during pretraining
    * simplify conditional and add doc to config.qmd
jwongTensora | 8606093921 | fix for indexing error from token/embeddings mismatch (#2257) | 2025-01-14 22:09:29 -05:00
    Co-authored-by: jwong <jwongTensora@gmail.com>
NanoCode012 | cba5a457d9 | fix: use text_column even when not packing for pretraining (#2254) | 2025-01-14 22:08:56 -05:00
    * fix: use text_column even when not packing for pretraining
    * feat: update test to check when not packing
    * chore: lint
    * Update src/axolotl/utils/data/pretraining.py
      Co-authored-by: Wing Lian <wing.lian@gmail.com>
    ---------
    Co-authored-by: Wing Lian <wing@axolotl.ai>
    Co-authored-by: Wing Lian <wing.lian@gmail.com>
Wing Lian | 19cd83d408 | rename references to dpo dataset prep to pref data (#2258) | 2025-01-14 22:07:55 -05:00
20 changed files with 150 additions and 54 deletions

View File

@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
+pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -244,6 +244,8 @@ total_num_tokens:
 sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
+# whether to concatenate samples during pretraining
+pretraining_sample_concatenation:
 # Use batch flattening for speedups when not using sample_packing
 batch_flattening:
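For readers trying the new flag, a minimal config sketch in the style of the repository's tests follows; the base model and dataset values are placeholders and not taken from this change set, while the flag and text_column usage mirror the config.qmd and pretraining.py changes in this compare.

from axolotl.utils.dict import DictDefault

# Sketch only: exercises the new pretraining_sample_concatenation flag.
cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # placeholder model
        "pretraining_dataset": [
            {"path": "allenai/c4", "name": "en", "text_column": "text"},  # placeholder dataset
        ],
        "pretraining_sample_concatenation": False,  # pad samples individually instead of packing
        "sequence_len": 1024,
        "micro_batch_size": 2,
        "max_steps": 100,
    }
)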

View File

@@ -13,9 +13,9 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers==4.47.1
+transformers @ git+https://github.com/huggingface/transformers.git@mueller-trainer-refactor
 tokenizers>=0.21.0
-accelerate==1.2.1
+accelerate==1.3.0
 datasets==3.2.0
 deepspeed==0.16.1
 trl==0.13.0

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
@@ -103,9 +103,9 @@ def load_preference_datasets(
     cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
 ) -> TrainDatasetMeta:
     """
-    Loads one or more training or evaluation datasets for DPO training, calling
-    `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
-    information.
+    Loads one or more training or evaluation datasets for RL training using paired
+    preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
+    Optionally, logs out debug information.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -115,7 +115,7 @@ def load_preference_datasets(
         Dataclass with fields for training and evaluation datasets and the computed
         `total_num_steps`.
     """
-    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
+    train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )

View File

@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None

         if self.cfg.model_config_type == "mamba":
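As context for the collator fallback added above: with concatenation disabled, pretraining batches contain ragged samples, and DataCollatorForSeq2Seq pads them per batch. A standalone sketch, with a placeholder tokenizer:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # placeholder tokenizer
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(tokenizer)

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = collator(features)
# input_ids/attention_mask are padded to the longest sample in the batch;
# labels are padded with -100 so the loss ignores padding positions.
print(batch["input_ids"].shape)  # torch.Size([2, 3])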

View File

@@ -14,15 +14,85 @@ LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
 ORIGINAL_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+            if self.model_accepts_loss_kwargs:
+                loss = self.compute_loss(model, inputs)
+            else:
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if num_items_in_batch is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """

 PATCHED_CONTEXT_CODE = """
         with self.compute_loss_context_manager():
-            if self.model_accepts_loss_kwargs:
-                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-            else:
-                loss = self.compute_loss(model, inputs)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_torch_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_torch_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                torch.musa.empty_cache()
+            elif is_torch_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_mps_available(min_version="2.0"):
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
+
+        kwargs = {}
+
+        # For LOMO optimizers you need to explicitly use the learnign rate
+        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            kwargs["learning_rate"] = self._get_learning_rate()
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+        if self.use_apex:
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            # Finally we need to normalize the loss for reporting
+            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
+                loss = loss / self.args.gradient_accumulation_steps
 """

 ORIGINAL_LLAMA_FCLM_CODE = """
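For orientation, these context strings are compared against the live Trainer.training_step source before the monkeypatch is applied; a simplified sketch of that kind of check (not the repository's exact implementation) could look like:

import inspect

from transformers import Trainer

def training_step_matches(expected_block: str) -> bool:
    # The patch is only safe to apply if the installed transformers release
    # still contains the exact block of code we plan to string-replace.
    source = inspect.getsource(Trainer.training_step)
    return expected_block in source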

View File

@@ -706,6 +706,12 @@ class AxolotlInputConfig(
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
-from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
+from axolotl.utils.data.rl import load_prepare_preference_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,

View File

@@ -18,10 +18,14 @@ LOG = logging.getLogger("axolotl")
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
+    tokenizer: PreTrainedTokenizerBase,
+    max_tokens: int,
+    examples: Dict[str, List],
+    text_column: str = "text",
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples[text_column],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
@@ -30,6 +34,13 @@
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
@@ -196,7 +207,13 @@ def wrap_pretraining_dataset(
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
-        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+        encode = functools.partial(
+            encode_pretraining,
+            tokenizer,
+            max_tokens,
+            text_column=cfg.pretraining_dataset[0].text_column or "text",
+            concatenate=cfg.pretraining_sample_concatenation is True,
+        )

     if cfg.shuffle_merged_datasets:
         dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
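To make the new parameters concrete, here is a small usage sketch of encode_pretraining with concatenation disabled; the tokenizer and column name are placeholders:

from transformers import AutoTokenizer

from axolotl.utils.data.pretraining import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # placeholder tokenizer

examples = {"content": ["a short pretraining document", "another, slightly longer pretraining document"]}

# With concatenate=False every document stays on its own row (ragged lengths),
# which the DataCollatorForSeq2Seq fallback in the trainer builder then pads per batch.
encoded = encode_pretraining(
    tokenizer,
    max_tokens=2048,
    examples=examples,
    text_column="content",
    concatenate=False,
)
assert len(encoded["input_ids"]) == 2  # one row per input document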

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
         raise ValueError("Unknown RL type")

-def load_prepare_dpo_datasets(cfg):
+def load_prepare_preference_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
         for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -386,16 +386,15 @@ class ModelLoader:
         if self.cfg.flash_attention:
             self.patch_attention()

-        if self.cfg.model_config_type == "llama":
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_flash_attention_forward,
-                patch_forward_for_ga,
-                patch_training_step_for_ga,
-            )
-
-            patch_flash_attention_forward()
-            patch_forward_for_ga()
-            patch_training_step_for_ga()
+        # if self.cfg.model_config_type == "llama":
+        #     from axolotl.monkeypatch.trainer_grad_accum import (  # patch_forward_for_ga,
+        #         patch_flash_attention_forward,
+        #         patch_training_step_for_ga,
+        #     )
+        #
+        #     patch_flash_attention_forward()
+        #     # patch_forward_for_ga()
+        #     patch_training_step_for_ga()

         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
@@ -1057,7 +1056,7 @@ class ModelLoader:
             )
         if (
             hasattr(self.model, "get_input_embeddings")
-            and self.model.get_input_embeddings().num_embeddings < embeddings_len
+            and self.model.get_input_embeddings().num_embeddings != embeddings_len
         ):
             resize_kwargs = {}
             if self.cfg.mean_resizing_embeddings is not None:
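The `<` to `!=` change above makes the loader resize embeddings whenever the embedding table and tokenizer length disagree in either direction, avoiding out-of-range token ids. A standalone sketch of the same idea with plain transformers (the model name is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # grows the vocab by one token

embeddings_len = len(tokenizer)
if model.get_input_embeddings().num_embeddings != embeddings_len:
    # Resize in either direction so token ids emitted by the tokenizer can
    # never index past the end of the embedding matrix.
    model.resize_token_embeddings(embeddings_len)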

View File

@@ -102,9 +102,5 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

View File

@@ -49,12 +49,7 @@ class TestModelPatches(unittest.TestCase):
         )
         normalize_config(cfg)
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=False)
-        assert (
-            "MixtralFlashAttention2"
-            in model.model.layers[0].self_attn.__class__.__name__
-        )
+        load_model(cfg, tokenizer, inference=False)

     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):

View File

@@ -3,8 +3,6 @@ import unittest
 import pytest

-from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable

 @pytest.mark.skip(
     reason="Unsloth integration will be broken going into latest transformers"
@@ -13,6 +11,8 @@ class TestUnslothIntegration(unittest.TestCase):
     """Unsloth monkeypatch integration tests."""

     def test_is_self_attn_patchable(self):
+        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
+
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
             check_self_attn_is_patchable(),

View File

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
+from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")

 os.environ["WANDB_DISABLED"] = "true"

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
 import logging
 import os
-import unittest
+
+import pytest

 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
@@ -12,19 +13,22 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import check_model_output_exists

 LOG = logging.getLogger("axolotl.tests.e2e")

 os.environ["WANDB_DISABLED"] = "true"

-class TestPretrainLlama(unittest.TestCase):
+class TestPretrainLlama:
     """
     Test case for Llama models w pretraining
     """

-    @with_temp_dir
-    def test_pretrain_w_sample_packing(self, temp_dir):
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_pretrain(self, temp_dir, sample_packing):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
                 "tokenizer_type": "LlamaTokenizer",
                 "flash_attention": True,
                 "sequence_len": 1024,
-                "sample_packing": True,
+                "sample_packing": sample_packing,
                 "special_tokens": {
                     "unk_token": "<unk>",
                     "bos_token": "<s>",

View File

@@ -1,6 +1,8 @@
""""Test module for checking whether the Hugging Face Transformers is working as expected.""" """"Test module for checking whether the Hugging Face Transformers is working as expected."""
import unittest import unittest
import pytest
from axolotl.monkeypatch.trainer_grad_accum import ( from axolotl.monkeypatch.trainer_grad_accum import (
check_forward_is_patchable, check_forward_is_patchable,
check_training_step_is_patchable, check_training_step_is_patchable,
@@ -10,6 +12,7 @@ from axolotl.monkeypatch.trainer_grad_accum import (
class TestTrainerGAIntegration(unittest.TestCase): class TestTrainerGAIntegration(unittest.TestCase):
"""llama monkeypatch integration tests.""" """llama monkeypatch integration tests."""
@pytest.mark.skip("may not be needed for latest transformers version")
def test_train_step_patchable(self): def test_train_step_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code # ensures the current version of transformers has loss code that matches our patching code
self.assertTrue( self.assertTrue(
@@ -17,6 +20,7 @@ class TestTrainerGAIntegration(unittest.TestCase):
"HF transformers Trainer.training_step has changed and isn't patchable", "HF transformers Trainer.training_step has changed and isn't patchable",
) )
@pytest.mark.skip("may not be needed for latest transformers version")
def test_model_forward_patchable(self): def test_model_forward_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code # ensures the current version of transformers has loss code that matches our patching code
self.assertTrue( self.assertTrue(

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

 from axolotl.utils.data import load_tokenized_prepared_datasets
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )

-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)

         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@
             }
         )

-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)

         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
 from transformers import AutoTokenizer

 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading with deduplication removes duplicates."""
         # Load the dataset using the deduplication setting
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

         # Verify that the dataset has been deduplicated
         assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading without deduplication retains duplicates."""
         self.cfg.dataset_exact_deduplication = False
         # Load the dataset without deduplication
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

         # Verify that the dataset retains duplicates
         assert (