fix linting issues
This commit is contained in:
@@ -80,6 +80,8 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
|||||||
def drop_long_rl_seq(
|
def drop_long_rl_seq(
|
||||||
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
|
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
|
||||||
):
|
):
|
||||||
|
result = None
|
||||||
|
|
||||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||||
if not (
|
if not (
|
||||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||||
@@ -97,47 +99,50 @@ def drop_long_rl_seq(
|
|||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||||
|
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
return (len_prompt + len_chosen) <= sequence_len and (
|
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
) <= sequence_len
|
) <= sequence_len
|
||||||
|
|
||||||
# truncate
|
# truncate
|
||||||
# If both sequences fit, return sample unchanged
|
else:
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
# If both sequences fit, return sample unchanged
|
||||||
len_prompt + len_rejected
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
) <= sequence_len:
|
len_prompt + len_rejected
|
||||||
return sample
|
) <= sequence_len:
|
||||||
|
result = sample
|
||||||
|
else:
|
||||||
|
# For truncation, we need to truncate the chosen and rejected responses
|
||||||
|
# to fit within sequence_len, but preserve the prompt
|
||||||
|
|
||||||
# For truncation, we need to truncate the chosen and rejected responses
|
# Calculate maximum response length that can fit with the prompt
|
||||||
# to fit within sequence_len, but preserve the prompt
|
max_response_len = sequence_len - len_prompt
|
||||||
|
|
||||||
# Calculate maximum response length that can fit with the prompt
|
if max_response_len <= 0:
|
||||||
max_response_len = sequence_len - len_prompt
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
result = False if handling == "drop" else sample
|
||||||
|
else:
|
||||||
|
# Truncate the chosen and rejected responses if needed
|
||||||
|
if len_chosen > max_response_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
][:max_response_len]
|
||||||
|
sample["chosen"] = tokenizer.decode(
|
||||||
|
chosen_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
if max_response_len <= 0:
|
if len_rejected > max_response_len:
|
||||||
# Prompt is already too long, we can't truncate effectively
|
# Tokenize, truncate, and decode
|
||||||
return False if handling == "drop" else sample
|
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
][:max_response_len]
|
||||||
|
sample["rejected"] = tokenizer.decode(
|
||||||
|
rejected_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
# Truncate the chosen and rejected responses if needed
|
result = sample
|
||||||
if len_chosen > max_response_len:
|
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][
|
|
||||||
:max_response_len
|
|
||||||
]
|
|
||||||
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
|
|
||||||
|
|
||||||
if len_rejected > max_response_len:
|
elif rl == "kto":
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
|
||||||
"input_ids"
|
|
||||||
][:max_response_len]
|
|
||||||
sample["rejected"] = tokenizer.decode(
|
|
||||||
rejected_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
if rl == "kto":
|
|
||||||
if not (sample.get("prompt") and sample.get("completion")):
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||||
|
|
||||||
@@ -150,36 +155,39 @@ def drop_long_rl_seq(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
return (len_prompt + len_completion) <= sequence_len
|
result = (len_prompt + len_completion) <= sequence_len
|
||||||
|
|
||||||
# truncate
|
# truncate
|
||||||
# If sequence fits, return sample unchanged
|
else:
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
# If sequence fits, return sample unchanged
|
||||||
return sample
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
|
result = sample
|
||||||
|
else:
|
||||||
|
# Calculate maximum completion length that can fit with the prompt
|
||||||
|
max_completion_len = sequence_len - len_prompt
|
||||||
|
|
||||||
# Calculate maximum completion length that can fit with the prompt
|
if max_completion_len <= 0:
|
||||||
max_completion_len = sequence_len - len_prompt
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
result = False if handling == "drop" else sample
|
||||||
|
else:
|
||||||
|
# Truncate the completion if needed
|
||||||
|
if len_completion > max_completion_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
completion_tokens = tokenizer(
|
||||||
|
completion, add_special_tokens=False
|
||||||
|
)["input_ids"][:max_completion_len]
|
||||||
|
sample["completion"] = tokenizer.decode(
|
||||||
|
completion_tokens, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
if max_completion_len <= 0:
|
result = sample
|
||||||
# Prompt is already too long, we can't truncate effectively
|
|
||||||
return False if handling == "drop" else sample
|
|
||||||
|
|
||||||
# Truncate the completion if needed
|
elif rl == "grpo":
|
||||||
if len_completion > max_completion_len:
|
result = True if handling == "drop" else sample
|
||||||
# Tokenize, truncate, and decode
|
else:
|
||||||
completion_tokens = tokenizer(completion, add_special_tokens=False)[
|
raise ValueError("Unknown RL type")
|
||||||
"input_ids"
|
|
||||||
][:max_completion_len]
|
|
||||||
sample["completion"] = tokenizer.decode(
|
|
||||||
completion_tokens, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return sample
|
return result
|
||||||
|
|
||||||
if rl == "grpo":
|
|
||||||
return True if handling == "drop" else sample
|
|
||||||
|
|
||||||
raise ValueError("Unknown RL type")
|
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_preference_datasets(cfg):
|
def load_prepare_preference_datasets(cfg):
|
||||||
|
|||||||
@@ -252,6 +252,7 @@ def truncate_or_drop_long_seq(
|
|||||||
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
|
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
|
||||||
"""
|
"""
|
||||||
min_sequence_len = min_sequence_len or 2
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
result = None
|
||||||
|
|
||||||
if handling == "drop":
|
if handling == "drop":
|
||||||
return drop_long_seq(sample, sequence_len, min_sequence_len)
|
return drop_long_seq(sample, sequence_len, min_sequence_len)
|
||||||
@@ -260,19 +261,16 @@ def truncate_or_drop_long_seq(
|
|||||||
|
|
||||||
# Edge case: if input_ids is empty
|
# Edge case: if input_ids is empty
|
||||||
if not input_ids:
|
if not input_ids:
|
||||||
return False if handling == "drop" else sample
|
result = False if handling == "drop" else sample
|
||||||
|
# Single example (input_ids is a list of int)
|
||||||
# Check if single example or batched by looking at the first element
|
elif isinstance(input_ids[0], int):
|
||||||
if isinstance(input_ids[0], int):
|
|
||||||
# Single example (input_ids is a list of int)
|
|
||||||
length = len(input_ids)
|
length = len(input_ids)
|
||||||
|
|
||||||
# Handle samples that are too short - always drop them
|
# Handle samples that are too short - always drop them
|
||||||
if length < min_sequence_len:
|
if length < min_sequence_len:
|
||||||
return False if handling == "drop" else sample
|
result = False if handling == "drop" else sample
|
||||||
|
|
||||||
# If truncation is enabled and the sample is too long, truncate it
|
# If truncation is enabled and the sample is too long, truncate it
|
||||||
if length > sequence_len and handling == "truncate":
|
elif length > sequence_len and handling == "truncate":
|
||||||
sample["input_ids"] = input_ids[:sequence_len]
|
sample["input_ids"] = input_ids[:sequence_len]
|
||||||
|
|
||||||
# Also truncate attention_mask if present
|
# Also truncate attention_mask if present
|
||||||
@@ -291,52 +289,58 @@ def truncate_or_drop_long_seq(
|
|||||||
if "length" in sample:
|
if "length" in sample:
|
||||||
sample["length"] = sequence_len
|
sample["length"] = sequence_len
|
||||||
|
|
||||||
return sample
|
result = sample
|
||||||
|
|
||||||
# For drop mode or if the sample doesn't exceed max length
|
# For drop mode or if the sample doesn't exceed max length
|
||||||
return (
|
else:
|
||||||
min_sequence_len <= length <= sequence_len if handling == "drop" else sample
|
result = (
|
||||||
)
|
min_sequence_len <= length <= sequence_len
|
||||||
|
if handling == "drop"
|
||||||
|
else sample
|
||||||
|
)
|
||||||
# Batched (input_ids is a list of lists)
|
# Batched (input_ids is a list of lists)
|
||||||
if handling == "drop":
|
else:
|
||||||
results = []
|
if handling == "drop":
|
||||||
for seq in input_ids:
|
results = []
|
||||||
length = len(seq)
|
for seq in input_ids:
|
||||||
results.append(min_sequence_len <= length <= sequence_len)
|
length = len(seq)
|
||||||
return results
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
else: # truncate
|
result = results
|
||||||
# Check each sequence in the batch
|
else: # truncate
|
||||||
for i, seq in enumerate(input_ids):
|
# Check each sequence in the batch
|
||||||
length = len(seq)
|
for i, seq in enumerate(input_ids):
|
||||||
|
length = len(seq)
|
||||||
|
|
||||||
# Skip sequences that are too short
|
# Skip sequences that are too short
|
||||||
if length < min_sequence_len:
|
if length < min_sequence_len:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Truncate sequences that are too long
|
# Truncate sequences that are too long
|
||||||
if length > sequence_len:
|
if length > sequence_len:
|
||||||
input_ids[i] = seq[:sequence_len]
|
input_ids[i] = seq[:sequence_len]
|
||||||
|
|
||||||
# Also truncate attention_mask if present
|
# Also truncate attention_mask if present
|
||||||
if "attention_mask" in sample:
|
if "attention_mask" in sample:
|
||||||
sample["attention_mask"][i] = sample["attention_mask"][i][
|
sample["attention_mask"][i] = sample["attention_mask"][i][
|
||||||
:sequence_len
|
:sequence_len
|
||||||
]
|
]
|
||||||
|
|
||||||
# Also truncate labels if present
|
# Also truncate labels if present
|
||||||
if "labels" in sample:
|
if "labels" in sample:
|
||||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||||
|
|
||||||
# Also truncate position_ids if present
|
# Also truncate position_ids if present
|
||||||
if "position_ids" in sample:
|
if "position_ids" in sample:
|
||||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
sample["position_ids"][i] = sample["position_ids"][i][
|
||||||
|
:sequence_len
|
||||||
|
]
|
||||||
|
|
||||||
# Update length if present
|
# Update length if present
|
||||||
if "length" in sample:
|
if "length" in sample:
|
||||||
sample["length"][i] = sequence_len
|
sample["length"][i] = sequence_len
|
||||||
|
|
||||||
return sample
|
result = sample
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
|
|||||||
Reference in New Issue
Block a user