Address review comments from winglian

Author: mhenrhcsen
Date: 2025-05-12 22:43:15 +02:00
parent be3c6bbd85
commit 5dd8f0b2b8
7 changed files with 262 additions and 90 deletions

View File

@@ -332,8 +332,8 @@ dataset_shard_idx:
# The maximum length of an input to train with; this should typically be less than 2048,
# as most models have a token/context limit of 2048
sequence_len: 2048
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
excess_token_handling: drop
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
sequence_len_overflow_handling: drop
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
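For reference, a minimal sketch of how the renamed option might appear in a config (values are illustrative, not a recommendation):

```yaml
# Hypothetical excerpt: cap samples at 2048 tokens and truncate
# overlong ones instead of dropping them
sequence_len: 2048
sequence_len_overflow_handling: truncate
pad_to_sequence_len: true
```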

View File

@@ -260,7 +260,9 @@ def encode_packed_pretraining(
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
# pass the handling mode through from the config attached to ds_wrapper
handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"),
handling=getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling", "drop"
),
)
sampler = MultipackBatchSampler(

View File

@@ -78,7 +78,11 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
):
result = None
@@ -98,32 +102,44 @@ def drop_long_rl_seq(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
if handling == "drop":
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
# truncate
else:
# Truncate mode: shorten overlong sequences and return the modified sample
if handling == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len:
result = sample
else:
# For truncation, we need to truncate the chosen and rejected responses
# to fit within sequence_len, but preserve the prompt
# Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt
if max_response_len <= 0:
# Prompt is already too long, we can't truncate effectively
result = False if handling == "drop" else sample
# The prompt alone exceeds sequence_len, so truncating the
# responses cannot help. Return the sample unchanged for the
# map path; a later filter step can drop it if still invalid.
result = sample
else:
# Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len:
# Tokenize, truncate, and decode
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
"input_ids"
][:max_response_len]
@@ -132,15 +148,17 @@ def drop_long_rl_seq(
)
if len_rejected > max_response_len:
# Tokenize, truncate, and decode
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
elif rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
@@ -154,36 +172,36 @@ def drop_long_rl_seq(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
if handling == "drop":
result = (len_prompt + len_completion) <= sequence_len
# truncate
else:
# Truncate mode: return the (possibly shortened) sample
if handling == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
result = sample
else:
# Calculate maximum completion length that can fit with the prompt
# Calculate maximum completion length
max_completion_len = sequence_len - len_prompt
if max_completion_len <= 0:
# Prompt is already too long, we can't truncate effectively
result = False if handling == "drop" else sample
# The prompt alone exceeds sequence_len; return the sample unchanged and let a later filter drop it if needed
result = sample
else:
# Truncate the completion if needed
if len_completion > max_completion_len:
# Tokenize, truncate, and decode
completion_tokens = tokenizer(
completion, add_special_tokens=False
)["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
elif rl == "grpo":
result = True if handling == "drop" else sample
# GRPO samples are not length-checked here: pass everything through
# (True for the filter path, the unchanged sample for the map path)
result = sample if handling == "truncate" else True
else:
raise ValueError("Unknown RL type")
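The truncation branches above all follow the same tokenize, slice, decode round trip while leaving the prompt intact. A standalone sketch of that pattern, assuming a Hugging Face tokenizer (the helper name `truncate_response` is illustrative, not part of the codebase):

```python
from transformers import AutoTokenizer

def truncate_response(tokenizer, prompt: str, response: str, sequence_len: int) -> str:
    """Shorten `response` so that prompt + response fits in sequence_len tokens."""
    len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
    max_response_len = sequence_len - len_prompt
    if max_response_len <= 0:
        # The prompt alone is too long; truncating the response cannot help
        return response
    token_ids = tokenizer(response, add_special_tokens=False)["input_ids"]
    if len(token_ids) <= max_response_len:
        return response
    # Decode the kept prefix back to text, mirroring the diff above
    return tokenizer.decode(token_ids[:max_response_len], skip_special_tokens=True)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(truncate_response(tokenizer, "Why?", "Because " * 500, sequence_len=32))
```

Note that decoding a truncated token sequence is not always a byte-exact prefix of the original string; that is an accepted tradeoff of this approach.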
@@ -234,21 +252,34 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
# Determine handling mode
handling = cfg.get("sequence_len_overflow_handling", "drop")
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
handling=cfg.get("excess_token_handling", "drop"),
handling=handling,
)
prior_len = len(split_datasets[i])
# Use filter for drop mode and map for truncate mode
handling = cfg.get("excess_token_handling", "drop")
if handling == "drop":
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
split_datasets[i] = split_datasets[i].map(
drop_long, # returns the (possibly truncated) sample in truncate mode
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# map preserves dataset length; overlong samples are truncated in place
LOG.info(
f"Truncated long sequences in dataset index {i} to {cfg.sequence_len} tokens"
)
else: # handling == "drop"
split_datasets[i] = split_datasets[i].filter(
drop_long,
drop_long, # returns a boolean in drop mode
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
@@ -258,16 +289,6 @@ def load_prepare_preference_datasets(cfg):
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
else:
split_datasets[i] = split_datasets[i].map(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
LOG.info(
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
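The filter/map dispatch relies on the `datasets` contract: `Dataset.filter` expects the callable to return a boolean, while `Dataset.map` expects it to return the (possibly modified) sample. A minimal self-contained sketch with a toy length check:

```python
from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2], [1, 2, 3, 4, 5]]})
MAX_LEN = 3

# drop mode: the callable returns a boolean and filter removes failing rows
dropped = ds.filter(lambda s: len(s["input_ids"]) <= MAX_LEN)
assert len(dropped) == 1

# truncate mode: the callable returns a modified sample and map keeps every row
truncated = ds.map(lambda s: {"input_ids": s["input_ids"][:MAX_LEN]})
assert len(truncated) == 2
assert all(len(ids) <= MAX_LEN for ids in truncated["input_ids"])
```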

View File

@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq
from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = logging.getLogger(__name__)
@@ -166,23 +166,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
return dataset
# Get the handling method from config, default to "drop" for backward compatibility
handling = cfg.get("excess_token_handling", "drop")
handling = cfg.get("sequence_len_overflow_handling", "drop")
if handling == "drop":
# Use the existing drop_long_seq function for backward compatibility
seq_handler = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)
else: # handling == "truncate"
# Use the new function with truncate mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
# Use the new function with the specified handling mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
@@ -206,12 +198,21 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
drop_long_kwargs = {}
if filter_map_kwargs:
if handling == "drop":
drop_long_kwargs["desc"] = "Dropping Long Sequences"
else:
if handling == "truncate":
drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = "Dropping Long Sequences"
if handling == "drop":
if handling == "truncate":
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
else: # handling == "drop"
# Use filter for drop mode
dataset = dataset.filter(
seq_handler,
@@ -223,14 +224,5 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
else:
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
return dataset

View File

@@ -186,10 +186,10 @@ class AxolotlInputConfig(
unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(default=512)
excess_token_handling: Literal["drop", "truncate"] = Field(
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop",
json_schema_extra={
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
},
)
min_sample_len: int | None = None
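The `Literal["drop", "truncate"]` annotation means any other value is rejected at config-validation time rather than deep inside dataset processing. A self-contained sketch of the same pattern (a toy model, not the actual AxolotlInputConfig):

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError

class Cfg(BaseModel):
    sequence_len: int = Field(default=512)
    sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(default="drop")

print(Cfg().sequence_len_overflow_handling)  # "drop" (the default)
print(Cfg(sequence_len_overflow_handling="truncate").sequence_len_overflow_handling)
try:
    Cfg(sequence_len_overflow_handling="pad")  # not one of the allowed literals
except ValidationError:
    print("invalid handling mode rejected")
```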

View File

@@ -484,22 +484,24 @@ def process_pretraining_datasets_for_packing(
drop_attention_mask=False,
handling="drop",
):
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
# Build the sequence handler once; it serves both drop and truncate modes
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
handling=handling,
)
# Use filter for drop mode and map for truncate mode
if handling == "drop":
train_dataset = train_dataset.filter(
drop_long_fn,
desc="Dropping Long Sequences",
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
train_dataset = train_dataset.map(
seq_handler_fn,
desc="Truncating Long Sequences",
load_from_cache_file=False,
)
else:
truncate_fn = partial(
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
)
train_dataset = train_dataset.map(
truncate_fn,
desc="Truncating Long Sequences",
else: # handling == "drop"
train_dataset = train_dataset.filter(
seq_handler_fn, # the same handler returns a boolean in drop mode
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
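The new test file below pins down the contract of `truncate_or_drop_long_seq` for both single and batched samples. For orientation, here is a minimal sketch of the behavior those tests assume; the real implementation lives in `axolotl.utils.trainer`, and this is not its actual code:

```python
def truncate_or_drop_long_seq(sample, sequence_len, min_sequence_len=2, handling="drop"):
    """Sketch only: drop mode returns bool(s); truncate mode trims the sample."""
    ids = sample["input_ids"]
    # A batched sample holds a list of token-id lists rather than one list
    batched = bool(ids) and isinstance(ids[0], list)

    if handling == "drop":
        def ok(seq):
            return min_sequence_len <= len(seq) <= sequence_len
        return [ok(seq) for seq in ids] if batched else ok(ids)

    # handling == "truncate": clip every list-valued field (input_ids, labels, ...)
    for key, value in sample.items():
        if isinstance(value, list):
            sample[key] = (
                [seq[:sequence_len] for seq in value] if batched else value[:sequence_len]
            )
    return sample
```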

tests/test_trainer_utils.py (new file, 155 lines)
View File

@@ -0,0 +1,155 @@
import unittest
from functools import partial
import pytest
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# In truncate mode, too-short samples pass through unchanged;
# dropping them is the responsibility of a separate filter step upstream
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len)), "labels": list(range(self.min_sequence_len))}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()