Refactor truncation logic in drop_long_rl_seq function

- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
This commit is contained in:
mhenrhcsen
2025-05-12 14:40:10 +02:00
parent 17a5838d38
commit f07db4f853

View File

@@ -100,7 +100,8 @@ def drop_long_rl_seq(
return (len_prompt + len_chosen) <= sequence_len and ( return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len ) <= sequence_len
else: # truncate
# truncate
# If both sequences fit, return sample unchanged # If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and ( if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
@@ -120,12 +121,10 @@ def drop_long_rl_seq(
# Truncate the chosen and rejected responses if needed # Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len: if len_chosen > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[ chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][
"input_ids" :max_response_len
][:max_response_len] ]
sample["chosen"] = tokenizer.decode( sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len: if len_rejected > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
@@ -152,7 +151,8 @@ def drop_long_rl_seq(
if handling == "drop": if handling == "drop":
return (len_prompt + len_completion) <= sequence_len return (len_prompt + len_completion) <= sequence_len
else: # truncate
# truncate
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
return sample return sample