Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling

This commit is contained in:
mhenrhcsen
2025-08-12 20:53:28 +02:00
parent 47b3fe8af3
commit 746c03b097

View File

@@ -122,18 +122,6 @@ def _map_dataset(
return dataset return dataset
<<<<<<< HEAD
def drop_long_rl_seq(
sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
):
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
=======
def _drop_long_sequences( def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool: ) -> bool:
@@ -152,7 +140,6 @@ def _drop_long_sequences(
ValueError: If required keys are missing or RL type is unknown. ValueError: If required keys are missing or RL type is unknown.
""" """
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}: if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
>>>>>>> origin/main
if not ( if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected") sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
): ):
@@ -169,6 +156,7 @@ def _drop_long_sequences(
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
# Truncate first, then drop if still invalid (although truncate should handle it) # Truncate first, then drop if still invalid (although truncate should handle it)
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate": if handling == "truncate":
# If both sequences fit, return sample unchanged # If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and ( if (len_prompt + len_chosen) <= sequence_len and (
@@ -225,6 +213,7 @@ def _drop_long_sequences(
) )
# Truncate first # Truncate first
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate": if handling == "truncate":
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
@@ -255,17 +244,14 @@ def _drop_long_sequences(
result = (len_prompt + len_completion) <= sequence_len result = (len_prompt + len_completion) <= sequence_len
elif rl == RLType.GRPO: elif rl == RLType.GRPO:
# GRPO doesn't involve sequence length checks in the same way? # For GRPO always keep
# The original code returned True for drop. What should it return for truncate? result = True
# Let's assume for now it always passes.
result = sample if handling == "truncate" else True
else: else:
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
return result return result
<<<<<<< HEAD
def load_prepare_preference_datasets(cfg): def load_prepare_preference_datasets(cfg):
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len): def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
""" """
@@ -310,10 +296,8 @@ def load_prepare_preference_datasets(cfg):
config_dataset, use_auth_token, streaming=False config_dataset, use_auth_token, streaming=False
) )
split_datasets.append(ds) split_datasets.append(ds)
=======
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training. """Load and process dataset split for RL training.
>>>>>>> origin/main
Args: Args:
cfg: Configuration object containing dataset settings. cfg: Configuration object containing dataset settings.
@@ -325,7 +309,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
split_datasets: list[Dataset | DatasetDict] = [] split_datasets: list[Dataset | DatasetDict] = []
<<<<<<< HEAD
map_kwargs = {} map_kwargs = {}
if isinstance(ds_transform_fn, tuple): if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn ds_transform_fn, map_kwargs = ds_transform_fn
@@ -461,11 +444,9 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
if cfg.dataset_exact_deduplication: if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset train_dataset=train_dataset, eval_dataset=eval_dataset
=======
for dataset_config in datasets_with_name_generator(datasets_configs): for dataset_config in datasets_with_name_generator(datasets_configs):
dataset: Dataset | DatasetDict = load_dataset_with_config( dataset: Dataset | DatasetDict = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=False dataset_config, cfg.hf_use_auth_token, streaming=False
>>>>>>> origin/main
) )
split_datasets.append(dataset) split_datasets.append(dataset)