Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling
This commit is contained in:
@@ -122,18 +122,6 @@ def _map_dataset(
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
<<<<<<< HEAD
|
|
||||||
def drop_long_rl_seq(
|
|
||||||
sample,
|
|
||||||
rl,
|
|
||||||
tokenizer,
|
|
||||||
sequence_len,
|
|
||||||
handling="drop", # Use the default handling mode
|
|
||||||
):
|
|
||||||
result = None
|
|
||||||
|
|
||||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
|
||||||
=======
|
|
||||||
def _drop_long_sequences(
|
def _drop_long_sequences(
|
||||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -152,7 +140,6 @@ def _drop_long_sequences(
|
|||||||
ValueError: If required keys are missing or RL type is unknown.
|
ValueError: If required keys are missing or RL type is unknown.
|
||||||
"""
|
"""
|
||||||
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||||
>>>>>>> origin/main
|
|
||||||
if not (
|
if not (
|
||||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||||
):
|
):
|
||||||
@@ -169,6 +156,7 @@ def _drop_long_sequences(
|
|||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||||
|
|
||||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||||
|
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||||
if handling == "truncate":
|
if handling == "truncate":
|
||||||
# If both sequences fit, return sample unchanged
|
# If both sequences fit, return sample unchanged
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
@@ -225,6 +213,7 @@ def _drop_long_sequences(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Truncate first
|
# Truncate first
|
||||||
|
handling = sample.get("sequence_len_overflow_handling", "drop")
|
||||||
if handling == "truncate":
|
if handling == "truncate":
|
||||||
# If sequence fits, return sample unchanged
|
# If sequence fits, return sample unchanged
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
@@ -255,17 +244,14 @@ def _drop_long_sequences(
|
|||||||
result = (len_prompt + len_completion) <= sequence_len
|
result = (len_prompt + len_completion) <= sequence_len
|
||||||
|
|
||||||
elif rl == RLType.GRPO:
|
elif rl == RLType.GRPO:
|
||||||
# GRPO doesn't involve sequence length checks in the same way?
|
# For GRPO always keep
|
||||||
# The original code returned True for drop. What should it return for truncate?
|
result = True
|
||||||
# Let's assume for now it always passes.
|
|
||||||
result = sample if handling == "truncate" else True
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
<<<<<<< HEAD
|
|
||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
|
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
|
||||||
"""
|
"""
|
||||||
@@ -310,10 +296,8 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
config_dataset, use_auth_token, streaming=False
|
config_dataset, use_auth_token, streaming=False
|
||||||
)
|
)
|
||||||
split_datasets.append(ds)
|
split_datasets.append(ds)
|
||||||
=======
|
|
||||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||||
"""Load and process dataset split for RL training.
|
"""Load and process dataset split for RL training.
|
||||||
>>>>>>> origin/main
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Configuration object containing dataset settings.
|
cfg: Configuration object containing dataset settings.
|
||||||
@@ -325,7 +309,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||||
split_datasets: list[Dataset | DatasetDict] = []
|
split_datasets: list[Dataset | DatasetDict] = []
|
||||||
|
|
||||||
<<<<<<< HEAD
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if isinstance(ds_transform_fn, tuple):
|
if isinstance(ds_transform_fn, tuple):
|
||||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
@@ -461,11 +444,9 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication:
|
||||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||||
=======
|
|
||||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||||
>>>>>>> origin/main
|
|
||||||
)
|
)
|
||||||
split_datasets.append(dataset)
|
split_datasets.append(dataset)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user