diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 84a47e85b..e1d5f8c9f 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -100,44 +100,43 @@ def drop_long_rl_seq(
         return (len_prompt + len_chosen) <= sequence_len and (
             len_prompt + len_rejected
         ) <= sequence_len
-        else:  # truncate
-            # If both sequences fit, return sample unchanged
-            if (len_prompt + len_chosen) <= sequence_len and (
-                len_prompt + len_rejected
-            ) <= sequence_len:
-                return sample
-
-            # For truncation, we need to truncate the chosen and rejected responses
-            # to fit within sequence_len, but preserve the prompt
-
-            # Calculate maximum response length that can fit with the prompt
-            max_response_len = sequence_len - len_prompt
-
-            if max_response_len <= 0:
-                # Prompt is already too long, we can't truncate effectively
-                return False if handling == "drop" else sample
-
-            # Truncate the chosen and rejected responses if needed
-            if len_chosen > max_response_len:
-                # Tokenize, truncate, and decode
-                chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
-                    "input_ids"
-                ][:max_response_len]
-                sample["chosen"] = tokenizer.decode(
-                    chosen_tokens, skip_special_tokens=True
-                )
-
-            if len_rejected > max_response_len:
-                # Tokenize, truncate, and decode
-                rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
-                    "input_ids"
-                ][:max_response_len]
-                sample["rejected"] = tokenizer.decode(
-                    rejected_tokens, skip_special_tokens=True
-                )
+        # truncate
+        # If both sequences fit, return sample unchanged
+        if (len_prompt + len_chosen) <= sequence_len and (
+            len_prompt + len_rejected
+        ) <= sequence_len:
             return sample
 
+        # For truncation, we need to truncate the chosen and rejected responses
+        # to fit within sequence_len, but preserve the prompt
+
+        # Calculate maximum response length that can fit with the prompt
+        max_response_len = sequence_len - len_prompt
+
+        if max_response_len <= 0:
+            # Prompt is already too long, we can't truncate effectively
+            return False if handling == "drop" else sample
+
+        # Truncate the chosen and rejected responses if needed
+        if len_chosen > max_response_len:
+            # Tokenize, truncate, and decode
+            chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][
+                :max_response_len
+            ]
+            sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
+
+        if len_rejected > max_response_len:
+            # Tokenize, truncate, and decode
+            rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
+                "input_ids"
+            ][:max_response_len]
+            sample["rejected"] = tokenizer.decode(
+                rejected_tokens, skip_special_tokens=True
+            )
+
+        return sample
+
     if rl == "kto":
         if not (sample.get("prompt") and sample.get("completion")):
             raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -152,30 +151,31 @@ def drop_long_rl_seq(
 
         if handling == "drop":
             return (len_prompt + len_completion) <= sequence_len
-        else:  # truncate
-            # If sequence fits, return sample unchanged
-            if (len_prompt + len_completion) <= sequence_len:
-                return sample
-
-            # Calculate maximum completion length that can fit with the prompt
-            max_completion_len = sequence_len - len_prompt
-
-            if max_completion_len <= 0:
-                # Prompt is already too long, we can't truncate effectively
-                return False if handling == "drop" else sample
-
-            # Truncate the completion if needed
-            if len_completion > max_completion_len:
-                # Tokenize, truncate, and decode
-                completion_tokens = tokenizer(completion, add_special_tokens=False)[
-                    "input_ids"
-                ][:max_completion_len]
-                sample["completion"] = tokenizer.decode(
-                    completion_tokens, skip_special_tokens=True
-                )
+        # truncate
+        # If sequence fits, return sample unchanged
+        if (len_prompt + len_completion) <= sequence_len:
             return sample
 
+        # Calculate maximum completion length that can fit with the prompt
+        max_completion_len = sequence_len - len_prompt
+
+        if max_completion_len <= 0:
+            # Prompt is already too long, we can't truncate effectively
+            return False if handling == "drop" else sample
+
+        # Truncate the completion if needed
+        if len_completion > max_completion_len:
+            # Tokenize, truncate, and decode
+            completion_tokens = tokenizer(completion, add_special_tokens=False)[
+                "input_ids"
+            ][:max_completion_len]
+            sample["completion"] = tokenizer.decode(
+                completion_tokens, skip_special_tokens=True
+            )
+
+        return sample
+
     if rl == "grpo":
         return True if handling == "drop" else sample