Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation of the chosen and rejected responses so that they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with the existing handling options.
This commit is contained in:
@@ -100,7 +100,8 @@ def drop_long_rl_seq(
|
|||||||
return (len_prompt + len_chosen) <= sequence_len and (
|
return (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
) <= sequence_len
|
) <= sequence_len
|
||||||
else: # truncate
|
|
||||||
|
# truncate
|
||||||
# If both sequences fit, return sample unchanged
|
# If both sequences fit, return sample unchanged
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
@@ -120,12 +121,10 @@ def drop_long_rl_seq(
|
|||||||
# Truncate the chosen and rejected responses if needed
|
# Truncate the chosen and rejected responses if needed
|
||||||
if len_chosen > max_response_len:
|
if len_chosen > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
# Tokenize, truncate, and decode
|
||||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][
|
||||||
"input_ids"
|
:max_response_len
|
||||||
][:max_response_len]
|
]
|
||||||
sample["chosen"] = tokenizer.decode(
|
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
|
||||||
chosen_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if len_rejected > max_response_len:
|
if len_rejected > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
# Tokenize, truncate, and decode
|
||||||
@@ -152,7 +151,8 @@ def drop_long_rl_seq(
|
|||||||
|
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
return (len_prompt + len_completion) <= sequence_len
|
return (len_prompt + len_completion) <= sequence_len
|
||||||
else: # truncate
|
|
||||||
|
# truncate
|
||||||
# If sequence fits, return sample unchanged
|
# If sequence fits, return sample unchanged
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
return sample
|
return sample
|
||||||
|
|||||||
Reference in New Issue
Block a user