Fixes comments from winglian

This commit is contained in:
mhenrhcsen
2025-05-12 22:43:15 +02:00
parent be3c6bbd85
commit 5dd8f0b2b8
7 changed files with 262 additions and 90 deletions

View File

@@ -332,8 +332,8 @@ dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048 # The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens) # How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
excess_token_handling: drop sequence_len_overflow_handling: drop
# Pad inputs so each step uses constant sized buffers # Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:

View File

@@ -260,7 +260,9 @@ def encode_packed_pretraining(
# workaround by using the position id logic for now in trainer # workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn, drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function # pass through handling mode from config via ds_wrapper function
handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"), handling=getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling", "drop"
),
) )
sampler = MultipackBatchSampler( sampler = MultipackBatchSampler(

View File

@@ -78,7 +78,11 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq( def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
): ):
result = None result = None
@@ -98,32 +102,44 @@ def drop_long_rl_seq(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"]) len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
if handling == "drop": # Truncate first, then drop if still invalid (although truncate should handle it)
result = (len_prompt + len_chosen) <= sequence_len and ( if handling == "truncate":
len_prompt + len_rejected
) <= sequence_len
# truncate
else:
# If both sequences fit, return sample unchanged # If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and ( if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len: ) <= sequence_len:
result = sample result = sample
else: else:
# For truncation, we need to truncate the chosen and rejected responses
# to fit within sequence_len, but preserve the prompt
# Calculate maximum response length that can fit with the prompt # Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt max_response_len = sequence_len - len_prompt
if max_response_len <= 0: if max_response_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt is already too long, behavior depends on handling
result = False if handling == "drop" else sample # If truncate is chosen, we technically can't truncate, but drop seems harsh.
# Returning the sample might be unexpected. Let's stick to the filter logic
# which would drop this in the `filter` step later if needed.
# For now, return sample to map, or False to filter.
# Let's simplify: truncate *should* result in a valid sample if possible.
# If prompt >= seq_len, truncate won't work. Filter will catch this later.
# So, if max_response_len <= 0, we pass it through for map, drop for filter.
# However, the filter/map logic is applied *after* this function.
# This function needs to return the *modified* sample for map, or bool for filter.
# Re-think: If handling==truncate, return the modified sample if possible.
# If prompt >= seq_len, modification is impossible. What should map return?
# Maybe return the original sample? But map expects *modified* sample.
# Let's stick to the original logic: if prompt is too long, return False for filter
# and original sample for map.
result = (
sample # For map, let downstream handle it if still invalid?
)
# Or maybe return None/empty dict? Let's return sample for now.
# If handling was drop, filter would remove this.
else: else:
# Truncate the chosen and rejected responses if needed # Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len: if len_chosen > max_response_len:
# Tokenize, truncate, and decode
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[ chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
"input_ids" "input_ids"
][:max_response_len] ][:max_response_len]
@@ -132,15 +148,17 @@ def drop_long_rl_seq(
) )
if len_rejected > max_response_len: if len_rejected > max_response_len:
# Tokenize, truncate, and decode
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[ rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
"input_ids" "input_ids"
][:max_response_len] ][:max_response_len]
sample["rejected"] = tokenizer.decode( sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True rejected_tokens, skip_special_tokens=True
) )
result = sample result = sample
else: # handling == "drop"
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
elif rl == "kto": elif rl == "kto":
if not (sample.get("prompt") and sample.get("completion")): if not (sample.get("prompt") and sample.get("completion")):
@@ -154,36 +172,36 @@ def drop_long_rl_seq(
tokenizer(completion, add_special_tokens=False)["input_ids"] tokenizer(completion, add_special_tokens=False)["input_ids"]
) )
if handling == "drop": # Truncate first
result = (len_prompt + len_completion) <= sequence_len if handling == "truncate":
# truncate
else:
# If sequence fits, return sample unchanged # If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len: if (len_prompt + len_completion) <= sequence_len:
result = sample result = sample
else: else:
# Calculate maximum completion length that can fit with the prompt # Calculate maximum completion length
max_completion_len = sequence_len - len_prompt max_completion_len = sequence_len - len_prompt
if max_completion_len <= 0: if max_completion_len <= 0:
# Prompt is already too long, we can't truncate effectively # Prompt too long, return sample for map
result = False if handling == "drop" else sample result = sample
else: else:
# Truncate the completion if needed # Truncate the completion if needed
if len_completion > max_completion_len: if len_completion > max_completion_len:
# Tokenize, truncate, and decode
completion_tokens = tokenizer( completion_tokens = tokenizer(
completion, add_special_tokens=False completion, add_special_tokens=False
)["input_ids"][:max_completion_len] )["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode( sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True completion_tokens, skip_special_tokens=True
) )
result = sample result = sample
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
elif rl == "grpo": elif rl == "grpo":
result = True if handling == "drop" else sample # GRPO doesn't involve sequence length checks in the same way?
# The original code returned True for drop. What should it return for truncate?
# Let's assume for now it always passes.
result = sample if handling == "truncate" else True
else: else:
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
@@ -234,21 +252,34 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = data_set split_datasets[i] = data_set
if not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
# Determine handling mode
handling = cfg.get("sequence_len_overflow_handling", "drop")
drop_long = partial( drop_long = partial(
drop_long_rl_seq, drop_long_rl_seq,
rl=_cfg.rl, rl=_cfg.rl,
tokenizer=tokenizer, tokenizer=tokenizer,
sequence_len=cfg.sequence_len, sequence_len=cfg.sequence_len,
handling=cfg.get("excess_token_handling", "drop"), handling=handling, # Pass the handling mode
) )
prior_len = len(split_datasets[i]) prior_len = len(split_datasets[i])
# Use filter for drop mode and map for truncate mode # Use map for truncate mode and filter for drop mode
handling = cfg.get("excess_token_handling", "drop") if handling == "truncate":
if handling == "drop": split_datasets[i] = split_datasets[i].map(
drop_long, # Function now returns modified sample or original
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Note: Length might not change if truncation always occurs
LOG.info(
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
)
else: # handling == "drop"
split_datasets[i] = split_datasets[i].filter( split_datasets[i] = split_datasets[i].filter(
drop_long, drop_long, # Function now returns boolean
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
@@ -258,16 +289,6 @@ def load_prepare_preference_datasets(cfg):
LOG.warning( LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}" f"Dropped {dropped} long samples from dataset index {i}"
) )
else:
split_datasets[i] = split_datasets[i].map(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
LOG.info(
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
)
combined_datasets = concatenate_datasets(split_datasets) combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed) combined_datasets = combined_datasets.shuffle(seed=cfg.seed)

View File

@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -166,23 +166,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
return dataset return dataset
# Get the handling method from config, default to "drop" for backward compatibility # Get the handling method from config, default to "drop" for backward compatibility
handling = cfg.get("excess_token_handling", "drop") handling = cfg.get("sequence_len_overflow_handling", "drop")
if handling == "drop": # Use the new function with the specified handling mode
# Use the existing drop_long_seq function for backward compatibility seq_handler = functools.partial(
seq_handler = functools.partial( truncate_or_drop_long_seq,
drop_long_seq, sequence_len=cfg.sequence_len,
sequence_len=cfg.sequence_len, min_sequence_len=cfg.min_sample_len,
min_sequence_len=cfg.min_sample_len, handling=handling,
) )
else: # handling == "truncate"
# Use the new function with truncate mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
try: try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True) ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
@@ -206,12 +198,21 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
if handling == "drop": if handling == "truncate":
drop_long_kwargs["desc"] = "Dropping Long Sequences"
else:
drop_long_kwargs["desc"] = "Truncating Long Sequences" drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = "Dropping Long Sequences"
if handling == "drop": if handling == "truncate":
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
else: # handling == "drop"
# Use filter for drop mode # Use filter for drop mode
dataset = dataset.filter( dataset = dataset.filter(
seq_handler, seq_handler,
@@ -223,14 +224,5 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
dropped = prior_len - len(dataset) dropped = prior_len - len(dataset)
if dropped: if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset") LOG.warning(f"Dropped {dropped} long samples from dataset")
else:
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
return dataset return dataset

View File

@@ -186,10 +186,10 @@ class AxolotlInputConfig(
unfrozen_parameters: list[str] | None = None unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(default=512) sequence_len: int = Field(default=512)
excess_token_handling: Literal["drop", "truncate"] = Field( sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop", default="drop",
json_schema_extra={ json_schema_extra={
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate" "description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
}, },
) )
min_sample_len: int | None = None min_sample_len: int | None = None

View File

@@ -484,22 +484,24 @@ def process_pretraining_datasets_for_packing(
drop_attention_mask=False, drop_attention_mask=False,
handling="drop", handling="drop",
): ):
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len) # Define the function to use for handling sequences based on the mode
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
handling=handling, # Pass handling mode
)
# Use filter for drop mode and map for truncate mode # Use map for truncate mode and filter for drop mode
if handling == "drop": if handling == "truncate":
train_dataset = train_dataset.filter( train_dataset = train_dataset.map(
drop_long_fn, seq_handler_fn,
desc="Dropping Long Sequences", desc="Truncating Long Sequences",
load_from_cache_file=False, load_from_cache_file=False,
) )
else: else: # handling == "drop"
truncate_fn = partial( train_dataset = train_dataset.filter(
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling seq_handler_fn, # Use the same function, it returns boolean for drop mode
) desc="Dropping Long Sequences",
train_dataset = train_dataset.map(
truncate_fn,
desc="Truncating Long Sequences",
load_from_cache_file=False, load_from_cache_file=False,
) )

155
tests/test_trainer_utils.py Normal file
View File

@@ -0,0 +1,155 @@
import unittest
from functools import partial
import pytest
# Assuming the function is in axolotl.utils.trainer
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# Too short (should still be dropped implicitly by filter/map logic upstream,
# but the function itself might return the sample or False based on impl.)
# Current impl returns the original sample for map if too short, assuming upstream filters.
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len)), "labels": list(range(self.min_sequence_len))}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()