Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling

This commit is contained in:
mhenrhcsen
2025-08-12 20:53:28 +02:00
parent 47b3fe8af3
commit 746c03b097

View File

@@ -122,18 +122,6 @@ def _map_dataset(
return dataset
<<<<<<< HEAD
def drop_long_rl_seq(
sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
):
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
=======
def _drop_long_sequences(
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
@@ -152,7 +140,6 @@ def _drop_long_sequences(
ValueError: If required keys are missing or RL type is unknown.
"""
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
>>>>>>> origin/main
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
@@ -169,6 +156,7 @@ def _drop_long_sequences(
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
# Truncate first, then drop if still invalid (although truncate should handle it)
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
@@ -225,6 +213,7 @@ def _drop_long_sequences(
)
# Truncate first
handling = sample.get("sequence_len_overflow_handling", "drop")
if handling == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
@@ -255,17 +244,14 @@ def _drop_long_sequences(
result = (len_prompt + len_completion) <= sequence_len
elif rl == RLType.GRPO:
# GRPO doesn't involve sequence length checks in the same way?
# The original code returned True for drop. What should it return for truncate?
# Let's assume for now it always passes.
result = sample if handling == "truncate" else True
# For GRPO always keep
result = True
else:
raise ValueError("Unknown RL type")
return result
<<<<<<< HEAD
def load_prepare_preference_datasets(cfg):
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
"""
@@ -310,10 +296,8 @@ def load_prepare_preference_datasets(cfg):
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
=======
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
>>>>>>> origin/main
Args:
cfg: Configuration object containing dataset settings.
@@ -325,7 +309,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
split_datasets: list[Dataset | DatasetDict] = []
<<<<<<< HEAD
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
@@ -461,11 +444,9 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
=======
for dataset_config in datasets_with_name_generator(datasets_configs):
dataset: Dataset | DatasetDict = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=False
>>>>>>> origin/main
)
split_datasets.append(dataset)