Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt. - Improved readability by restructuring the code and removing redundant checks. - Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
This commit is contained in:
@@ -100,44 +100,43 @@ def drop_long_rl_seq(
|
|||||||
return (len_prompt + len_chosen) <= sequence_len and (
|
return (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
) <= sequence_len
|
) <= sequence_len
|
||||||
else: # truncate
|
|
||||||
# If both sequences fit, return sample unchanged
|
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
|
||||||
len_prompt + len_rejected
|
|
||||||
) <= sequence_len:
|
|
||||||
return sample
|
|
||||||
|
|
||||||
# For truncation, we need to truncate the chosen and rejected responses
|
|
||||||
# to fit within sequence_len, but preserve the prompt
|
|
||||||
|
|
||||||
# Calculate maximum response length that can fit with the prompt
|
|
||||||
max_response_len = sequence_len - len_prompt
|
|
||||||
|
|
||||||
if max_response_len <= 0:
|
|
||||||
# Prompt is already too long, we can't truncate effectively
|
|
||||||
return False if handling == "drop" else sample
|
|
||||||
|
|
||||||
# Truncate the chosen and rejected responses if needed
|
|
||||||
if len_chosen > max_response_len:
|
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
|
||||||
"input_ids"
|
|
||||||
][:max_response_len]
|
|
||||||
sample["chosen"] = tokenizer.decode(
|
|
||||||
chosen_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if len_rejected > max_response_len:
|
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
|
||||||
"input_ids"
|
|
||||||
][:max_response_len]
|
|
||||||
sample["rejected"] = tokenizer.decode(
|
|
||||||
rejected_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# truncate
|
||||||
|
# If both sequences fit, return sample unchanged
|
||||||
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
|
len_prompt + len_rejected
|
||||||
|
) <= sequence_len:
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
# For truncation, we need to truncate the chosen and rejected responses
|
||||||
|
# to fit within sequence_len, but preserve the prompt
|
||||||
|
|
||||||
|
# Calculate maximum response length that can fit with the prompt
|
||||||
|
max_response_len = sequence_len - len_prompt
|
||||||
|
|
||||||
|
if max_response_len <= 0:
|
||||||
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Truncate the chosen and rejected responses if needed
|
||||||
|
if len_chosen > max_response_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][
|
||||||
|
:max_response_len
|
||||||
|
]
|
||||||
|
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
if len_rejected > max_response_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
][:max_response_len]
|
||||||
|
sample["rejected"] = tokenizer.decode(
|
||||||
|
rejected_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
if rl == "kto":
|
if rl == "kto":
|
||||||
if not (sample.get("prompt") and sample.get("completion")):
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||||
@@ -152,30 +151,31 @@ def drop_long_rl_seq(
|
|||||||
|
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
return (len_prompt + len_completion) <= sequence_len
|
return (len_prompt + len_completion) <= sequence_len
|
||||||
else: # truncate
|
|
||||||
# If sequence fits, return sample unchanged
|
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
|
||||||
return sample
|
|
||||||
|
|
||||||
# Calculate maximum completion length that can fit with the prompt
|
|
||||||
max_completion_len = sequence_len - len_prompt
|
|
||||||
|
|
||||||
if max_completion_len <= 0:
|
|
||||||
# Prompt is already too long, we can't truncate effectively
|
|
||||||
return False if handling == "drop" else sample
|
|
||||||
|
|
||||||
# Truncate the completion if needed
|
|
||||||
if len_completion > max_completion_len:
|
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
completion_tokens = tokenizer(completion, add_special_tokens=False)[
|
|
||||||
"input_ids"
|
|
||||||
][:max_completion_len]
|
|
||||||
sample["completion"] = tokenizer.decode(
|
|
||||||
completion_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# truncate
|
||||||
|
# If sequence fits, return sample unchanged
|
||||||
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
# Calculate maximum completion length that can fit with the prompt
|
||||||
|
max_completion_len = sequence_len - len_prompt
|
||||||
|
|
||||||
|
if max_completion_len <= 0:
|
||||||
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Truncate the completion if needed
|
||||||
|
if len_completion > max_completion_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
completion_tokens = tokenizer(completion, add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
][:max_completion_len]
|
||||||
|
sample["completion"] = tokenizer.decode(
|
||||||
|
completion_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
if rl == "grpo":
|
if rl == "grpo":
|
||||||
return True if handling == "drop" else sample
|
return True if handling == "drop" else sample
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user