fix linting issues

This commit is contained in:
mhenrhcsen
2025-05-12 14:46:57 +02:00
parent f07db4f853
commit be3c6bbd85
2 changed files with 113 additions and 101 deletions

View File

@@ -80,6 +80,8 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq( def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
): ):
result = None
if rl in ("dpo", "ipo", "orpo", "simpo"): if rl in ("dpo", "ipo", "orpo", "simpo"):
if not ( if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected") sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
@@ -97,17 +99,18 @@ def drop_long_rl_seq(
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
if handling == "drop": if handling == "drop":
return (len_prompt + len_chosen) <= sequence_len and ( result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len ) <= sequence_len
# truncate # truncate
else:
# If both sequences fit, return sample unchanged # If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and ( if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len: ) <= sequence_len:
return sample result = sample
else:
# For truncation, we need to truncate the chosen and rejected responses # For truncation, we need to truncate the chosen and rejected responses
# to fit within sequence_len, but preserve the prompt # to fit within sequence_len, but preserve the prompt
@@ -116,15 +119,17 @@ def drop_long_rl_seq(
if max_response_len <= 0: if max_response_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt is already too long, we can't truncate effectively
return False if handling == "drop" else sample result = False if handling == "drop" else sample
else:
# Truncate the chosen and rejected responses if needed # Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len: if len_chosen > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][ chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
:max_response_len "input_ids"
] ][:max_response_len]
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True) sample["chosen"] = tokenizer.decode(
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len: if len_rejected > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
@@ -135,9 +140,9 @@ def drop_long_rl_seq(
rejected_tokens, skip_special_tokens=True rejected_tokens, skip_special_tokens=True
) )
return sample result = sample
if rl == "kto": elif rl == "kto":
if not (sample.get("prompt") and sample.get("completion")): if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets") raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -150,37 +155,40 @@ def drop_long_rl_seq(
) )
if handling == "drop": if handling == "drop":
return (len_prompt + len_completion) <= sequence_len result = (len_prompt + len_completion) <= sequence_len
# truncate # truncate
else:
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
return sample result = sample
else:
# Calculate maximum completion length that can fit with the prompt # Calculate maximum completion length that can fit with the prompt
max_completion_len = sequence_len - len_prompt max_completion_len = sequence_len - len_prompt
if max_completion_len <= 0: if max_completion_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt is already too long, we can't truncate effectively
return False if handling == "drop" else sample result = False if handling == "drop" else sample
else:
# Truncate the completion if needed # Truncate the completion if needed
if len_completion > max_completion_len: if len_completion > max_completion_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
completion_tokens = tokenizer(completion, add_special_tokens=False)[ completion_tokens = tokenizer(
"input_ids" completion, add_special_tokens=False
][:max_completion_len] )["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode( sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True completion_tokens, skip_special_tokens=True
) )
return sample result = sample
if rl == "grpo":
return True if handling == "drop" else sample
elif rl == "grpo":
result = True if handling == "drop" else sample
else:
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
return result
def load_prepare_preference_datasets(cfg): def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):

View File

@@ -252,6 +252,7 @@ def truncate_or_drop_long_seq(
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode). Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
""" """
min_sequence_len = min_sequence_len or 2 min_sequence_len = min_sequence_len or 2
result = None
if handling == "drop": if handling == "drop":
return drop_long_seq(sample, sequence_len, min_sequence_len) return drop_long_seq(sample, sequence_len, min_sequence_len)
@@ -260,19 +261,16 @@ def truncate_or_drop_long_seq(
# Edge case: if input_ids is empty # Edge case: if input_ids is empty
if not input_ids: if not input_ids:
return False if handling == "drop" else sample result = False if handling == "drop" else sample
# Check if single example or batched by looking at the first element
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int) # Single example (input_ids is a list of int)
elif isinstance(input_ids[0], int):
length = len(input_ids) length = len(input_ids)
# Handle samples that are too short - always drop them # Handle samples that are too short - always drop them
if length < min_sequence_len: if length < min_sequence_len:
return False if handling == "drop" else sample result = False if handling == "drop" else sample
# If truncation is enabled and the sample is too long, truncate it # If truncation is enabled and the sample is too long, truncate it
if length > sequence_len and handling == "truncate": elif length > sequence_len and handling == "truncate":
sample["input_ids"] = input_ids[:sequence_len] sample["input_ids"] = input_ids[:sequence_len]
# Also truncate attention_mask if present # Also truncate attention_mask if present
@@ -291,20 +289,22 @@ def truncate_or_drop_long_seq(
if "length" in sample: if "length" in sample:
sample["length"] = sequence_len sample["length"] = sequence_len
return sample result = sample
# For drop mode or if the sample doesn't exceed max length # For drop mode or if the sample doesn't exceed max length
return ( else:
min_sequence_len <= length <= sequence_len if handling == "drop" else sample result = (
min_sequence_len <= length <= sequence_len
if handling == "drop"
else sample
) )
# Batched (input_ids is a list of lists) # Batched (input_ids is a list of lists)
else:
if handling == "drop": if handling == "drop":
results = [] results = []
for seq in input_ids: for seq in input_ids:
length = len(seq) length = len(seq)
results.append(min_sequence_len <= length <= sequence_len) results.append(min_sequence_len <= length <= sequence_len)
return results result = results
else: # truncate else: # truncate
# Check each sequence in the batch # Check each sequence in the batch
for i, seq in enumerate(input_ids): for i, seq in enumerate(input_ids):
@@ -330,13 +330,17 @@ def truncate_or_drop_long_seq(
# Also truncate position_ids if present # Also truncate position_ids if present
if "position_ids" in sample: if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] sample["position_ids"][i] = sample["position_ids"][i][
:sequence_len
]
# Update length if present # Update length if present
if "length" in sample: if "length" in sample:
sample["length"][i] = sequence_len sample["length"][i] = sequence_len
return sample result = sample
return result
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):