From 746c03b097e444287f4df1a24248af6e7fae115a Mon Sep 17 00:00:00 2001 From: mhenrhcsen Date: Tue, 12 Aug 2025 20:53:28 +0200 Subject: [PATCH] Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling --- src/axolotl/utils/data/rl.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 50f3b67d0..b95f86ab2 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -122,18 +122,6 @@ def _map_dataset( return dataset -<<<<<<< HEAD -def drop_long_rl_seq( - sample, - rl, - tokenizer, - sequence_len, - handling="drop", # Use the default handling mode -): - result = None - - if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): -======= def _drop_long_sequences( sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int ) -> bool: @@ -152,7 +140,6 @@ def _drop_long_sequences( ValueError: If required keys are missing or RL type is unknown. """ if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}: ->>>>>>> origin/main if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -169,6 +156,7 @@ def _drop_long_sequences( len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) # Truncate first, then drop if still invalid (although truncate should handle it) + handling = sample.get("sequence_len_overflow_handling", "drop") if handling == "truncate": # If both sequences fit, return sample unchanged if (len_prompt + len_chosen) <= sequence_len and ( @@ -225,6 +213,7 @@ def _drop_long_sequences( ) # Truncate first + handling = sample.get("sequence_len_overflow_handling", "drop") if handling == "truncate": # If sequence fits, return sample unchanged if (len_prompt + len_completion) <= sequence_len: @@ -255,17 +244,14 @@ def _drop_long_sequences( result = (len_prompt + len_completion) <= sequence_len elif rl == RLType.GRPO: - # GRPO doesn't involve sequence length checks in the same way? - # The original code returned True for drop. What should it return for truncate? - # Let's assume for now it always passes. - result = sample if handling == "truncate" else True + # For GRPO always keep + result = True else: raise ValueError("Unknown RL type") return result -<<<<<<< HEAD def load_prepare_preference_datasets(cfg): def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len): """ @@ -310,10 +296,8 @@ def load_prepare_preference_datasets(cfg): config_dataset, use_auth_token, streaming=False ) split_datasets.append(ds) -======= def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: """Load and process dataset split for RL training. ->>>>>>> origin/main Args: cfg: Configuration object containing dataset settings. @@ -325,7 +309,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets split_datasets: list[Dataset | DatasetDict] = [] -<<<<<<< HEAD map_kwargs = {} if isinstance(ds_transform_fn, tuple): ds_transform_fn, map_kwargs = ds_transform_fn @@ -461,11 +444,9 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: if cfg.dataset_exact_deduplication: train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( train_dataset=train_dataset, eval_dataset=eval_dataset -======= for dataset_config in datasets_with_name_generator(datasets_configs): dataset: Dataset | DatasetDict = load_dataset_with_config( dataset_config, cfg.hf_use_auth_token, streaming=False ->>>>>>> origin/main ) split_datasets.append(dataset)