diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index e1d5f8c9f..38dd08963 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -80,6 +80,8 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): def drop_long_rl_seq( sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name ): + result = None + if rl in ("dpo", "ipo", "orpo", "simpo"): if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") @@ -97,47 +99,50 @@ def drop_long_rl_seq( len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) if handling == "drop": - return (len_prompt + len_chosen) <= sequence_len and ( + result = (len_prompt + len_chosen) <= sequence_len and ( len_prompt + len_rejected ) <= sequence_len # truncate - # If both sequences fit, return sample unchanged - if (len_prompt + len_chosen) <= sequence_len and ( - len_prompt + len_rejected - ) <= sequence_len: - return sample + else: + # If both sequences fit, return sample unchanged + if (len_prompt + len_chosen) <= sequence_len and ( + len_prompt + len_rejected + ) <= sequence_len: + result = sample + else: + # For truncation, we need to truncate the chosen and rejected responses + # to fit within sequence_len, but preserve the prompt - # For truncation, we need to truncate the chosen and rejected responses - # to fit within sequence_len, but preserve the prompt + # Calculate maximum response length that can fit with the prompt + max_response_len = sequence_len - len_prompt - # Calculate maximum response length that can fit with the prompt - max_response_len = sequence_len - len_prompt + if max_response_len <= 0: + # Prompt is already too long, we can't truncate effectively + result = False if handling == "drop" else sample + else: + # Truncate the chosen and rejected responses if needed + if len_chosen > max_response_len: + # Tokenize, truncate, and decode + chosen_tokens = tokenizer(chosen, add_special_tokens=False)[ + "input_ids" + ][:max_response_len] + sample["chosen"] = tokenizer.decode( + chosen_tokens, skip_special_tokens=True + ) - if max_response_len <= 0: - # Prompt is already too long, we can't truncate effectively - return False if handling == "drop" else sample + if len_rejected > max_response_len: + # Tokenize, truncate, and decode + rejected_tokens = tokenizer(rejected, add_special_tokens=False)[ + "input_ids" + ][:max_response_len] + sample["rejected"] = tokenizer.decode( + rejected_tokens, skip_special_tokens=True + ) - # Truncate the chosen and rejected responses if needed - if len_chosen > max_response_len: - # Tokenize, truncate, and decode - chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][ - :max_response_len - ] - sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True) + result = sample - if len_rejected > max_response_len: - # Tokenize, truncate, and decode - rejected_tokens = tokenizer(rejected, add_special_tokens=False)[ - "input_ids" - ][:max_response_len] - sample["rejected"] = tokenizer.decode( - rejected_tokens, skip_special_tokens=True - ) - - return sample - - if rl == "kto": + elif rl == "kto": if not (sample.get("prompt") and sample.get("completion")): raise ValueError("Prompt and completion keys are required for KTO datasets") @@ -150,36 +155,39 @@ def drop_long_rl_seq( ) if handling == "drop": - return (len_prompt + len_completion) <= sequence_len + result = (len_prompt + len_completion) <= sequence_len # truncate - # If sequence fits, return sample unchanged - if (len_prompt + len_completion) <= sequence_len: - return sample + else: + # If sequence fits, return sample unchanged + if (len_prompt + len_completion) <= sequence_len: + result = sample + else: + # Calculate maximum completion length that can fit with the prompt + max_completion_len = sequence_len - len_prompt - # Calculate maximum completion length that can fit with the prompt - max_completion_len = sequence_len - len_prompt + if max_completion_len <= 0: + # Prompt is already too long, we can't truncate effectively + result = False if handling == "drop" else sample + else: + # Truncate the completion if needed + if len_completion > max_completion_len: + # Tokenize, truncate, and decode + completion_tokens = tokenizer( + completion, add_special_tokens=False + )["input_ids"][:max_completion_len] + sample["completion"] = tokenizer.decode( + completion_tokens, skip_special_tokens=True + ) - if max_completion_len <= 0: - # Prompt is already too long, we can't truncate effectively - return False if handling == "drop" else sample + result = sample - # Truncate the completion if needed - if len_completion > max_completion_len: - # Tokenize, truncate, and decode - completion_tokens = tokenizer(completion, add_special_tokens=False)[ - "input_ids" - ][:max_completion_len] - sample["completion"] = tokenizer.decode( - completion_tokens, skip_special_tokens=True - ) + elif rl == "grpo": + result = True if handling == "drop" else sample + else: + raise ValueError("Unknown RL type") - return sample - - if rl == "grpo": - return True if handling == "drop" else sample - - raise ValueError("Unknown RL type") + return result def load_prepare_preference_datasets(cfg): diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f9c134e82..8996923a0 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -252,6 +252,7 @@ def truncate_or_drop_long_seq( Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode). """ min_sequence_len = min_sequence_len or 2 + result = None if handling == "drop": return drop_long_seq(sample, sequence_len, min_sequence_len) @@ -260,19 +261,16 @@ def truncate_or_drop_long_seq( # Edge case: if input_ids is empty if not input_ids: - return False if handling == "drop" else sample - - # Check if single example or batched by looking at the first element - if isinstance(input_ids[0], int): - # Single example (input_ids is a list of int) + result = False if handling == "drop" else sample + # Single example (input_ids is a list of int) + elif isinstance(input_ids[0], int): length = len(input_ids) # Handle samples that are too short - always drop them if length < min_sequence_len: - return False if handling == "drop" else sample - + result = False if handling == "drop" else sample # If truncation is enabled and the sample is too long, truncate it - if length > sequence_len and handling == "truncate": + elif length > sequence_len and handling == "truncate": sample["input_ids"] = input_ids[:sequence_len] # Also truncate attention_mask if present @@ -291,52 +289,58 @@ def truncate_or_drop_long_seq( if "length" in sample: sample["length"] = sequence_len - return sample - + result = sample # For drop mode or if the sample doesn't exceed max length - return ( - min_sequence_len <= length <= sequence_len if handling == "drop" else sample - ) - + else: + result = ( + min_sequence_len <= length <= sequence_len + if handling == "drop" + else sample + ) # Batched (input_ids is a list of lists) - if handling == "drop": - results = [] - for seq in input_ids: - length = len(seq) - results.append(min_sequence_len <= length <= sequence_len) - return results - else: # truncate - # Check each sequence in the batch - for i, seq in enumerate(input_ids): - length = len(seq) + else: + if handling == "drop": + results = [] + for seq in input_ids: + length = len(seq) + results.append(min_sequence_len <= length <= sequence_len) + result = results + else: # truncate + # Check each sequence in the batch + for i, seq in enumerate(input_ids): + length = len(seq) - # Skip sequences that are too short - if length < min_sequence_len: - continue + # Skip sequences that are too short + if length < min_sequence_len: + continue - # Truncate sequences that are too long - if length > sequence_len: - input_ids[i] = seq[:sequence_len] + # Truncate sequences that are too long + if length > sequence_len: + input_ids[i] = seq[:sequence_len] - # Also truncate attention_mask if present - if "attention_mask" in sample: - sample["attention_mask"][i] = sample["attention_mask"][i][ - :sequence_len - ] + # Also truncate attention_mask if present + if "attention_mask" in sample: + sample["attention_mask"][i] = sample["attention_mask"][i][ + :sequence_len + ] - # Also truncate labels if present - if "labels" in sample: - sample["labels"][i] = sample["labels"][i][:sequence_len] + # Also truncate labels if present + if "labels" in sample: + sample["labels"][i] = sample["labels"][i][:sequence_len] - # Also truncate position_ids if present - if "position_ids" in sample: - sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] + # Also truncate position_ids if present + if "position_ids" in sample: + sample["position_ids"][i] = sample["position_ids"][i][ + :sequence_len + ] - # Update length if present - if "length" in sample: - sample["length"][i] = sequence_len + # Update length if present + if "length" in sample: + sample["length"][i] = sequence_len - return sample + result = sample + + return result def process_datasets_for_packing(cfg, train_dataset, eval_dataset):