Fixes comments from winglian
This commit is contained in:
@@ -332,8 +332,8 @@ dataset_shard_idx:
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
|
||||
excess_token_handling: drop
|
||||
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
|
||||
sequence_len_overflow_handling: drop
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
|
||||
@@ -260,7 +260,9 @@ def encode_packed_pretraining(
|
||||
# workaround by using the position id logic for now in trainer
|
||||
drop_attention_mask=multipack_attn,
|
||||
# pass through handling mode from config via ds_wrapper function
|
||||
handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"),
|
||||
handling=getattr(ds_wrapper, "cfg", {}).get(
|
||||
"sequence_len_overflow_handling", "drop"
|
||||
),
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
|
||||
@@ -78,7 +78,11 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
|
||||
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
|
||||
sample,
|
||||
rl,
|
||||
tokenizer,
|
||||
sequence_len,
|
||||
handling="drop", # Use the default handling mode
|
||||
):
|
||||
result = None
|
||||
|
||||
@@ -98,32 +102,44 @@ def drop_long_rl_seq(
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
if handling == "drop":
|
||||
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
# truncate
|
||||
else:
|
||||
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||
if handling == "truncate":
|
||||
# If both sequences fit, return sample unchanged
|
||||
if (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# For truncation, we need to truncate the chosen and rejected responses
|
||||
# to fit within sequence_len, but preserve the prompt
|
||||
|
||||
# Calculate maximum response length that can fit with the prompt
|
||||
max_response_len = sequence_len - len_prompt
|
||||
|
||||
if max_response_len <= 0:
|
||||
# Prompt is already too long, we can't truncate effectively
|
||||
result = False if handling == "drop" else sample
|
||||
# Prompt is already too long, behavior depends on handling
|
||||
# If truncate is chosen, we technically can't truncate, but drop seems harsh.
|
||||
# Returning the sample might be unexpected. Let's stick to the filter logic
|
||||
# which would drop this in the `filter` step later if needed.
|
||||
# For now, return sample to map, or False to filter.
|
||||
# Let's simplify: truncate *should* result in a valid sample if possible.
|
||||
# If prompt >= seq_len, truncate won't work. Filter will catch this later.
|
||||
# So, if max_response_len <= 0, we pass it through for map, drop for filter.
|
||||
# However, the filter/map logic is applied *after* this function.
|
||||
# This function needs to return the *modified* sample for map, or bool for filter.
|
||||
|
||||
# Re-think: If handling==truncate, return the modified sample if possible.
|
||||
# If prompt >= seq_len, modification is impossible. What should map return?
|
||||
# Maybe return the original sample? But map expects *modified* sample.
|
||||
# Let's stick to the original logic: if prompt is too long, return False for filter
|
||||
# and original sample for map.
|
||||
|
||||
result = (
|
||||
sample # For map, let downstream handle it if still invalid?
|
||||
)
|
||||
# Or maybe return None/empty dict? Let's return sample for now.
|
||||
# If handling was drop, filter would remove this.
|
||||
|
||||
else:
|
||||
# Truncate the chosen and rejected responses if needed
|
||||
if len_chosen > max_response_len:
|
||||
# Tokenize, truncate, and decode
|
||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
@@ -132,15 +148,17 @@ def drop_long_rl_seq(
|
||||
)
|
||||
|
||||
if len_rejected > max_response_len:
|
||||
# Tokenize, truncate, and decode
|
||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:max_response_len]
|
||||
sample["rejected"] = tokenizer.decode(
|
||||
rejected_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
elif rl == "kto":
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
@@ -154,36 +172,36 @@ def drop_long_rl_seq(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
|
||||
if handling == "drop":
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
# truncate
|
||||
else:
|
||||
# Truncate first
|
||||
if handling == "truncate":
|
||||
# If sequence fits, return sample unchanged
|
||||
if (len_prompt + len_completion) <= sequence_len:
|
||||
result = sample
|
||||
else:
|
||||
# Calculate maximum completion length that can fit with the prompt
|
||||
# Calculate maximum completion length
|
||||
max_completion_len = sequence_len - len_prompt
|
||||
|
||||
if max_completion_len <= 0:
|
||||
# Prompt is already too long, we can't truncate effectively
|
||||
result = False if handling == "drop" else sample
|
||||
# Prompt too long, return sample for map
|
||||
result = sample
|
||||
else:
|
||||
# Truncate the completion if needed
|
||||
if len_completion > max_completion_len:
|
||||
# Tokenize, truncate, and decode
|
||||
completion_tokens = tokenizer(
|
||||
completion, add_special_tokens=False
|
||||
)["input_ids"][:max_completion_len]
|
||||
sample["completion"] = tokenizer.decode(
|
||||
completion_tokens, skip_special_tokens=True
|
||||
)
|
||||
|
||||
result = sample
|
||||
else: # handling == "drop"
|
||||
result = (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
elif rl == "grpo":
|
||||
result = True if handling == "drop" else sample
|
||||
# GRPO doesn't involve sequence length checks in the same way?
|
||||
# The original code returned True for drop. What should it return for truncate?
|
||||
# Let's assume for now it always passes.
|
||||
result = sample if handling == "truncate" else True
|
||||
else:
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
@@ -234,21 +252,34 @@ def load_prepare_preference_datasets(cfg):
|
||||
split_datasets[i] = data_set
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
# Determine handling mode
|
||||
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
handling=cfg.get("excess_token_handling", "drop"),
|
||||
handling=handling, # Pass the handling mode
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
|
||||
# Use filter for drop mode and map for truncate mode
|
||||
handling = cfg.get("excess_token_handling", "drop")
|
||||
if handling == "drop":
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
split_datasets[i] = split_datasets[i].map(
|
||||
drop_long, # Function now returns modified sample or original
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Truncating Long Sequences",
|
||||
)
|
||||
# Note: Length might not change if truncation always occurs
|
||||
LOG.info(
|
||||
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
|
||||
)
|
||||
else: # handling == "drop"
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
drop_long, # Function now returns boolean
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
@@ -258,16 +289,6 @@ def load_prepare_preference_datasets(cfg):
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
else:
|
||||
split_datasets[i] = split_datasets[i].map(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Truncating Long Sequences",
|
||||
)
|
||||
LOG.info(
|
||||
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||
|
||||
@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -166,23 +166,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
return dataset
|
||||
|
||||
# Get the handling method from config, default to "drop" for backward compatibility
|
||||
handling = cfg.get("excess_token_handling", "drop")
|
||||
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||
|
||||
if handling == "drop":
|
||||
# Use the existing drop_long_seq function for backward compatibility
|
||||
seq_handler = functools.partial(
|
||||
drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
else: # handling == "truncate"
|
||||
# Use the new function with truncate mode
|
||||
seq_handler = functools.partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
handling=handling,
|
||||
)
|
||||
# Use the new function with the specified handling mode
|
||||
seq_handler = functools.partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=cfg.sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
handling=handling,
|
||||
)
|
||||
|
||||
try:
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
@@ -206,12 +198,21 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
if handling == "drop":
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
else:
|
||||
if handling == "truncate":
|
||||
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
||||
else: # handling == "drop"
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
|
||||
if handling == "drop":
|
||||
if handling == "truncate":
|
||||
# Use map for truncate mode
|
||||
dataset = dataset.map(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
||||
else: # handling == "drop"
|
||||
# Use filter for drop mode
|
||||
dataset = dataset.filter(
|
||||
seq_handler,
|
||||
@@ -223,14 +224,5 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
else:
|
||||
# Use map for truncate mode
|
||||
dataset = dataset.map(
|
||||
seq_handler,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -186,10 +186,10 @@ class AxolotlInputConfig(
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
sequence_len: int = Field(default=512)
|
||||
excess_token_handling: Literal["drop", "truncate"] = Field(
|
||||
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
|
||||
default="drop",
|
||||
json_schema_extra={
|
||||
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
|
||||
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
|
||||
},
|
||||
)
|
||||
min_sample_len: int | None = None
|
||||
|
||||
@@ -484,22 +484,24 @@ def process_pretraining_datasets_for_packing(
|
||||
drop_attention_mask=False,
|
||||
handling="drop",
|
||||
):
|
||||
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
|
||||
# Define the function to use for handling sequences based on the mode
|
||||
seq_handler_fn = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
handling=handling, # Pass handling mode
|
||||
)
|
||||
|
||||
# Use filter for drop mode and map for truncate mode
|
||||
if handling == "drop":
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long_fn,
|
||||
desc="Dropping Long Sequences",
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
train_dataset = train_dataset.map(
|
||||
seq_handler_fn,
|
||||
desc="Truncating Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
else:
|
||||
truncate_fn = partial(
|
||||
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
|
||||
)
|
||||
train_dataset = train_dataset.map(
|
||||
truncate_fn,
|
||||
desc="Truncating Long Sequences",
|
||||
else: # handling == "drop"
|
||||
train_dataset = train_dataset.filter(
|
||||
seq_handler_fn, # Use the same function, it returns boolean for drop mode
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
|
||||
155
tests/test_trainer_utils.py
Normal file
155
tests/test_trainer_utils.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
|
||||
# Assuming the function is in axolotl.utils.trainer
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
|
||||
# Test cases for truncate_or_drop_long_seq
|
||||
class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
"""
|
||||
Test suite for truncate_or_drop_long_seq function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Example sequence length settings
|
||||
self.sequence_len = 10
|
||||
self.min_sequence_len = 3
|
||||
|
||||
def test_drop_mode_single(self):
|
||||
"""Test drop mode with single examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
|
||||
# Too short
|
||||
sample_short = {"input_ids": [1, 2]}
|
||||
self.assertFalse(handler(sample_short))
|
||||
|
||||
# Too long
|
||||
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
|
||||
self.assertFalse(handler(sample_long))
|
||||
|
||||
# Just right
|
||||
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
|
||||
self.assertTrue(handler(sample_ok))
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": []}
|
||||
self.assertFalse(handler(sample_empty))
|
||||
|
||||
def test_truncate_mode_single(self):
|
||||
"""Test truncate mode with single examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
|
||||
# Too short (should still be dropped implicitly by filter/map logic upstream,
|
||||
# but the function itself might return the sample or False based on impl.)
|
||||
# Current impl returns the original sample for map if too short, assuming upstream filters.
|
||||
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
|
||||
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
|
||||
result_short = handler(sample_short)
|
||||
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
|
||||
|
||||
# Too long
|
||||
original_long = list(range(self.sequence_len + 5))
|
||||
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
|
||||
result_long = handler(sample_long)
|
||||
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
|
||||
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
|
||||
self.assertEqual(len(result_long["labels"]), self.sequence_len)
|
||||
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
|
||||
|
||||
|
||||
# Just right
|
||||
sample_ok = {"input_ids": list(range(self.min_sequence_len)), "labels": list(range(self.min_sequence_len))}
|
||||
result_ok = handler(sample_ok)
|
||||
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
|
||||
self.assertEqual(result_ok, sample_ok) # Should be unchanged
|
||||
|
||||
# Empty
|
||||
sample_empty = {"input_ids": [], "labels": []}
|
||||
result_empty = handler(sample_empty)
|
||||
self.assertEqual(result_empty, sample_empty) # Unchanged
|
||||
|
||||
|
||||
def test_drop_mode_batched(self):
|
||||
"""Test drop mode with batched examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="drop",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 1)), # Too long
|
||||
list(range(self.sequence_len)), # OK (len = 10)
|
||||
list(range(self.min_sequence_len)), # OK (len = 3)
|
||||
[], # Empty
|
||||
]
|
||||
}
|
||||
expected = [False, False, True, True, False]
|
||||
self.assertEqual(handler(sample), expected)
|
||||
|
||||
|
||||
def test_truncate_mode_batched(self):
|
||||
"""Test truncate mode with batched examples."""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
min_sequence_len=self.min_sequence_len,
|
||||
handling="truncate",
|
||||
)
|
||||
sample = {
|
||||
"input_ids": [
|
||||
[1, 2], # Too short
|
||||
list(range(self.sequence_len + 5)), # Too long
|
||||
list(range(self.sequence_len)), # OK
|
||||
list(range(self.min_sequence_len)), # OK
|
||||
[], # Empty
|
||||
],
|
||||
"labels": [ # Add labels to test truncation
|
||||
[1, 2],
|
||||
list(range(self.sequence_len + 5)),
|
||||
list(range(self.sequence_len)),
|
||||
list(range(self.min_sequence_len)),
|
||||
[],
|
||||
],
|
||||
}
|
||||
|
||||
result = handler(sample)
|
||||
|
||||
# Expected results after truncation (too short and empty remain unchanged by this function)
|
||||
expected_input_ids = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
expected_labels = [
|
||||
[1, 2], # Unchanged (too short)
|
||||
list(range(self.sequence_len)), # Truncated
|
||||
list(range(self.sequence_len)), # Unchanged (OK)
|
||||
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||
[], # Unchanged (Empty)
|
||||
]
|
||||
|
||||
|
||||
self.assertEqual(result["input_ids"], expected_input_ids)
|
||||
self.assertEqual(result["labels"], expected_labels)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user