Compare commits
1 Commits
textui
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0fccbadb79 |
@@ -180,15 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
def handle_long_seq_in_dataset(
|
||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to filter.
|
||||
sequence_len: Maximum length for sequences to keep
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
"""
|
||||
Remove or truncate sequences that exceed the configured maximum length from a dataset.
|
||||
|
||||
Parameters:
|
||||
dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
|
||||
sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
|
||||
cfg (DictDefault): Configuration object with keys:
|
||||
- excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
|
||||
- min_sample_len: minimum allowed sequence length (used when truncating or dropping).
|
||||
- dataset_num_proc: number of processes to use for non-streaming datasets.
|
||||
- is_preprocess: when true, bypasses cached preprocessing during filtering.
|
||||
|
||||
Returns:
|
||||
Filtered dataset with long sequences removed.
|
||||
Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
|
||||
"""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
@@ -206,10 +211,13 @@ def handle_long_seq_in_dataset(
|
||||
)
|
||||
return dataset
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
raise_on_drop=excess_length_strategy == "raise",
|
||||
)
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
@@ -230,7 +238,6 @@ def handle_long_seq_in_dataset(
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
truncate_long_seq,
|
||||
@@ -259,4 +266,4 @@ def handle_long_seq_in_dataset(
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
|
||||
return dataset
|
||||
return dataset
|
||||
@@ -201,16 +201,33 @@ def add_pose_position_ids(
|
||||
|
||||
|
||||
def add_length(sample):
    """
    Attach a "length" field to a sample, derived from its input tokens.

    Parameters:
        sample (Mapping-like): A sample containing an "input_ids" sequence.

    Returns:
        dict-like: The same sample object, mutated in place so that
        sample["length"] equals len(sample["input_ids"]).
    """
    token_count = len(sample["input_ids"])
    sample["length"] = token_count
    return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
Return whether a sample (single or batched) should be kept based on sequence length constraints.
|
||||
|
||||
Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
|
||||
|
||||
Parameters:
|
||||
sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
|
||||
sequence_len (int): Maximum allowed sequence length (inclusive).
|
||||
min_sequence_len (int): Minimum allowed sequence length (inclusive).
|
||||
raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
|
||||
|
||||
Returns:
|
||||
bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
@@ -225,12 +242,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
if isinstance(input_ids[0], int):
|
||||
# Single example (input_ids is a list of int)
|
||||
length = len(input_ids)
|
||||
if raise_on_drop and length > sequence_len:
|
||||
raise ValueError(
|
||||
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||
)
|
||||
return min_sequence_len <= length <= sequence_len
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
length = len(seq)
|
||||
if raise_on_drop and length > sequence_len:
|
||||
raise ValueError(
|
||||
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||
)
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
return results
|
||||
|
||||
@@ -715,4 +740,4 @@ def setup_trainer(
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
return trainer_builder.build(total_num_steps)
|
||||
return trainer_builder.build(total_num_steps)
|
||||
Reference in New Issue
Block a user