Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling
This commit is contained in:
@@ -122,18 +122,6 @@ def _map_dataset(
|
||||
return dataset
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def drop_long_rl_seq(
|
||||
sample,
|
||||
rl,
|
||||
tokenizer,
|
||||
sequence_len,
|
||||
handling="drop", # Use the default handling mode
|
||||
):
|
||||
result = None
|
||||
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
=======
|
||||
def _drop_long_sequences(
|
||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||
) -> bool:
|
||||
@@ -152,7 +140,6 @@ def _drop_long_sequences(
|
||||
ValueError: If required keys are missing or RL type is unknown.
|
||||
"""
|
||||
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||
>>>>>>> origin/main
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
@@ -169,6 +156,7 @@ def _drop_long_sequences(
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling == "truncate":
|
||||
# If both sequences fit, return sample unchanged
|
||||
if (len_prompt + len_chosen) <= sequence_len and (
|
||||
@@ -225,6 +213,7 @@ def _drop_long_sequences(
|
||||
)
|
||||
|
||||
# Truncate first
|
||||
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||
if handling == "truncate":
|
||||
# If sequence fits, return sample unchanged
|
||||
if (len_prompt + len_completion) <= sequence_len:
|
||||
@@ -255,17 +244,14 @@ def _drop_long_sequences(
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
elif rl == RLType.GRPO:
|
||||
# GRPO doesn't involve sequence length checks in the same way?
|
||||
# The original code returned True for drop. What should it return for truncate?
|
||||
# Let's assume for now it always passes.
|
||||
result = sample if handling == "truncate" else True
|
||||
# For GRPO always keep
|
||||
result = True
|
||||
else:
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
|
||||
"""
|
||||
@@ -310,10 +296,8 @@ def load_prepare_preference_datasets(cfg):
|
||||
config_dataset, use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(ds)
|
||||
=======
|
||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
"""Load and process dataset split for RL training.
|
||||
>>>>>>> origin/main
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing dataset settings.
|
||||
@@ -325,7 +309,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
split_datasets: list[Dataset | DatasetDict] = []
|
||||
|
||||
<<<<<<< HEAD
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
@@ -461,11 +444,9 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
=======
|
||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||
>>>>>>> origin/main
|
||||
)
|
||||
split_datasets.append(dataset)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user