This commit is contained in:
mhenrhcsen
2025-05-12 14:36:43 +02:00
parent 9f68918f13
commit 17a5838d38
5 changed files with 71 additions and 43 deletions

View File

@@ -332,7 +332,7 @@ dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048 # The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens) # How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
excess_token_handling: drop excess_token_handling: drop
# Pad inputs so each step uses constant sized buffers # Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently

View File

@@ -106,28 +106,36 @@ def drop_long_rl_seq(
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len: ) <= sequence_len:
return sample return sample
# For truncation, we need to truncate the chosen and rejected responses # For truncation, we need to truncate the chosen and rejected responses
# to fit within sequence_len, but preserve the prompt # to fit within sequence_len, but preserve the prompt
# Calculate maximum response length that can fit with the prompt # Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt max_response_len = sequence_len - len_prompt
if max_response_len <= 0: if max_response_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt is already too long, we can't truncate effectively
return False if handling == "drop" else sample return False if handling == "drop" else sample
# Truncate the chosen and rejected responses if needed # Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len: if len_chosen > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][:max_response_len] chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True) "input_ids"
][:max_response_len]
sample["chosen"] = tokenizer.decode(
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len: if len_rejected > max_response_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
rejected_tokens = tokenizer(rejected, add_special_tokens=False)["input_ids"][:max_response_len] rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
sample["rejected"] = tokenizer.decode(rejected_tokens, skip_special_tokens=True) "input_ids"
][:max_response_len]
sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True
)
return sample return sample
if rl == "kto": if rl == "kto":
@@ -148,20 +156,24 @@ def drop_long_rl_seq(
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
return sample return sample
# Calculate maximum completion length that can fit with the prompt # Calculate maximum completion length that can fit with the prompt
max_completion_len = sequence_len - len_prompt max_completion_len = sequence_len - len_prompt
if max_completion_len <= 0: if max_completion_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt is already too long, we can't truncate effectively
return False if handling == "drop" else sample return False if handling == "drop" else sample
# Truncate the completion if needed # Truncate the completion if needed
if len_completion > max_completion_len: if len_completion > max_completion_len:
# Tokenize, truncate, and decode # Tokenize, truncate, and decode
completion_tokens = tokenizer(completion, add_special_tokens=False)["input_ids"][:max_completion_len] completion_tokens = tokenizer(completion, add_special_tokens=False)[
sample["completion"] = tokenizer.decode(completion_tokens, skip_special_tokens=True) "input_ids"
][:max_completion_len]
sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True
)
return sample return sample
if rl == "grpo": if rl == "grpo":
@@ -223,7 +235,7 @@ def load_prepare_preference_datasets(cfg):
) )
prior_len = len(split_datasets[i]) prior_len = len(split_datasets[i])
# Use filter for drop mode and map for truncate mode # Use filter for drop mode and map for truncate mode
handling = cfg.get("excess_token_handling", "drop") handling = cfg.get("excess_token_handling", "drop")
if handling == "drop": if handling == "drop":
@@ -245,7 +257,9 @@ def load_prepare_preference_datasets(cfg):
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences", desc="Truncating Long Sequences",
) )
LOG.info(f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens") LOG.info(
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
)
combined_datasets = concatenate_datasets(split_datasets) combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed) combined_datasets = combined_datasets.shuffle(seed=cfg.seed)

View File

@@ -167,7 +167,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
# Get the handling method from config, default to "drop" for backward compatibility # Get the handling method from config, default to "drop" for backward compatibility
handling = cfg.get("excess_token_handling", "drop") handling = cfg.get("excess_token_handling", "drop")
if handling == "drop": if handling == "drop":
# Use the existing drop_long_seq function for backward compatibility # Use the existing drop_long_seq function for backward compatibility
seq_handler = functools.partial( seq_handler = functools.partial(

View File

@@ -188,7 +188,9 @@ class AxolotlInputConfig(
sequence_len: int = Field(default=512) sequence_len: int = Field(default=512)
excess_token_handling: Literal["drop", "truncate"] = Field( excess_token_handling: Literal["drop", "truncate"] = Field(
default="drop", default="drop",
json_schema_extra={"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"}, json_schema_extra={
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
},
) )
min_sample_len: int | None = None min_sample_len: int | None = None
max_prompt_len: int = Field( max_prompt_len: int = Field(

View File

@@ -235,7 +235,9 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return results return results
def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, handling="drop"): def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
""" """
Either drop or truncate samples whose sequence length is either too long (> sequence_len) Either drop or truncate samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len). or too short (< min_sequence_len).
@@ -264,35 +266,37 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
if isinstance(input_ids[0], int): if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int) # Single example (input_ids is a list of int)
length = len(input_ids) length = len(input_ids)
# Handle samples that are too short - always drop them # Handle samples that are too short - always drop them
if length < min_sequence_len: if length < min_sequence_len:
return False if handling == "drop" else sample return False if handling == "drop" else sample
# If truncation is enabled and the sample is too long, truncate it # If truncation is enabled and the sample is too long, truncate it
if length > sequence_len and handling == "truncate": if length > sequence_len and handling == "truncate":
sample["input_ids"] = input_ids[:sequence_len] sample["input_ids"] = input_ids[:sequence_len]
# Also truncate attention_mask if present # Also truncate attention_mask if present
if "attention_mask" in sample: if "attention_mask" in sample:
sample["attention_mask"] = sample["attention_mask"][:sequence_len] sample["attention_mask"] = sample["attention_mask"][:sequence_len]
# Also truncate labels if present # Also truncate labels if present
if "labels" in sample: if "labels" in sample:
sample["labels"] = sample["labels"][:sequence_len] sample["labels"] = sample["labels"][:sequence_len]
# Also truncate position_ids if present # Also truncate position_ids if present
if "position_ids" in sample: if "position_ids" in sample:
sample["position_ids"] = sample["position_ids"][:sequence_len] sample["position_ids"] = sample["position_ids"][:sequence_len]
# Update length if present # Update length if present
if "length" in sample: if "length" in sample:
sample["length"] = sequence_len sample["length"] = sequence_len
return sample return sample
# For drop mode or if the sample doesn't exceed max length # For drop mode or if the sample doesn't exceed max length
return min_sequence_len <= length <= sequence_len if handling == "drop" else sample return (
min_sequence_len <= length <= sequence_len if handling == "drop" else sample
)
# Batched (input_ids is a list of lists) # Batched (input_ids is a list of lists)
if handling == "drop": if handling == "drop":
@@ -305,31 +309,33 @@ def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, han
# Check each sequence in the batch # Check each sequence in the batch
for i, seq in enumerate(input_ids): for i, seq in enumerate(input_ids):
length = len(seq) length = len(seq)
# Skip sequences that are too short # Skip sequences that are too short
if length < min_sequence_len: if length < min_sequence_len:
continue continue
# Truncate sequences that are too long # Truncate sequences that are too long
if length > sequence_len: if length > sequence_len:
input_ids[i] = seq[:sequence_len] input_ids[i] = seq[:sequence_len]
# Also truncate attention_mask if present # Also truncate attention_mask if present
if "attention_mask" in sample: if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len] sample["attention_mask"][i] = sample["attention_mask"][i][
:sequence_len
]
# Also truncate labels if present # Also truncate labels if present
if "labels" in sample: if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len] sample["labels"][i] = sample["labels"][i][:sequence_len]
# Also truncate position_ids if present # Also truncate position_ids if present
if "position_ids" in sample: if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
# Update length if present # Update length if present
if "length" in sample: if "length" in sample:
sample["length"][i] = sequence_len sample["length"][i] = sequence_len
return sample return sample
@@ -468,10 +474,14 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing( def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False, handling="drop" train_dataset,
sequence_len,
skip_position_ids=True,
drop_attention_mask=False,
handling="drop",
): ):
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len) drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
# Use filter for drop mode and map for truncate mode # Use filter for drop mode and map for truncate mode
if handling == "drop": if handling == "drop":
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
@@ -480,13 +490,15 @@ def process_pretraining_datasets_for_packing(
load_from_cache_file=False, load_from_cache_file=False,
) )
else: else:
truncate_fn = partial(truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling) truncate_fn = partial(
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
)
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
truncate_fn, truncate_fn,
desc="Truncating Long Sequences", desc="Truncating Long Sequences",
load_from_cache_file=False, load_from_cache_file=False,
) )
if not skip_position_ids: if not skip_position_ids:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,