Fixes comments from winglian
This commit is contained in:
@@ -332,8 +332,8 @@ dataset_shard_idx:
|
|||||||
# The maximum length of an input to train with, this should typically be less than 2048
|
# The maximum length of an input to train with, this should typically be less than 2048
|
||||||
# as most models have a token/context limit of 2048
|
# as most models have a token/context limit of 2048
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
|
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
|
||||||
excess_token_handling: drop
|
sequence_len_overflow_handling: drop
|
||||||
# Pad inputs so each step uses constant sized buffers
|
# Pad inputs so each step uses constant sized buffers
|
||||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|||||||
@@ -260,7 +260,9 @@ def encode_packed_pretraining(
|
|||||||
# workaround by using the position id logic for now in trainer
|
# workaround by using the position id logic for now in trainer
|
||||||
drop_attention_mask=multipack_attn,
|
drop_attention_mask=multipack_attn,
|
||||||
# pass through handling mode from config via ds_wrapper function
|
# pass through handling mode from config via ds_wrapper function
|
||||||
handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"),
|
handling=getattr(ds_wrapper, "cfg", {}).get(
|
||||||
|
"sequence_len_overflow_handling", "drop"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler = MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
|
|||||||
@@ -78,7 +78,11 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def drop_long_rl_seq(
|
def drop_long_rl_seq(
|
||||||
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
|
sample,
|
||||||
|
rl,
|
||||||
|
tokenizer,
|
||||||
|
sequence_len,
|
||||||
|
handling="drop", # Use the default handling mode
|
||||||
):
|
):
|
||||||
result = None
|
result = None
|
||||||
|
|
||||||
@@ -98,32 +102,44 @@ def drop_long_rl_seq(
|
|||||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||||
|
|
||||||
if handling == "drop":
|
# Truncate first, then drop if still invalid (although truncate should handle it)
|
||||||
result = (len_prompt + len_chosen) <= sequence_len and (
|
if handling == "truncate":
|
||||||
len_prompt + len_rejected
|
|
||||||
) <= sequence_len
|
|
||||||
|
|
||||||
# truncate
|
|
||||||
else:
|
|
||||||
# If both sequences fit, return sample unchanged
|
# If both sequences fit, return sample unchanged
|
||||||
if (len_prompt + len_chosen) <= sequence_len and (
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
) <= sequence_len:
|
) <= sequence_len:
|
||||||
result = sample
|
result = sample
|
||||||
else:
|
else:
|
||||||
# For truncation, we need to truncate the chosen and rejected responses
|
|
||||||
# to fit within sequence_len, but preserve the prompt
|
|
||||||
|
|
||||||
# Calculate maximum response length that can fit with the prompt
|
# Calculate maximum response length that can fit with the prompt
|
||||||
max_response_len = sequence_len - len_prompt
|
max_response_len = sequence_len - len_prompt
|
||||||
|
|
||||||
if max_response_len <= 0:
|
if max_response_len <= 0:
|
||||||
# Prompt is already too long, we can't truncate effectively
|
# Prompt is already too long, behavior depends on handling
|
||||||
result = False if handling == "drop" else sample
|
# If truncate is chosen, we technically can't truncate, but drop seems harsh.
|
||||||
|
# Returning the sample might be unexpected. Let's stick to the filter logic
|
||||||
|
# which would drop this in the `filter` step later if needed.
|
||||||
|
# For now, return sample to map, or False to filter.
|
||||||
|
# Let's simplify: truncate *should* result in a valid sample if possible.
|
||||||
|
# If prompt >= seq_len, truncate won't work. Filter will catch this later.
|
||||||
|
# So, if max_response_len <= 0, we pass it through for map, drop for filter.
|
||||||
|
# However, the filter/map logic is applied *after* this function.
|
||||||
|
# This function needs to return the *modified* sample for map, or bool for filter.
|
||||||
|
|
||||||
|
# Re-think: If handling==truncate, return the modified sample if possible.
|
||||||
|
# If prompt >= seq_len, modification is impossible. What should map return?
|
||||||
|
# Maybe return the original sample? But map expects *modified* sample.
|
||||||
|
# Let's stick to the original logic: if prompt is too long, return False for filter
|
||||||
|
# and original sample for map.
|
||||||
|
|
||||||
|
result = (
|
||||||
|
sample # For map, let downstream handle it if still invalid?
|
||||||
|
)
|
||||||
|
# Or maybe return None/empty dict? Let's return sample for now.
|
||||||
|
# If handling was drop, filter would remove this.
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Truncate the chosen and rejected responses if needed
|
# Truncate the chosen and rejected responses if needed
|
||||||
if len_chosen > max_response_len:
|
if len_chosen > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
][:max_response_len]
|
][:max_response_len]
|
||||||
@@ -132,15 +148,17 @@ def drop_long_rl_seq(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len_rejected > max_response_len:
|
if len_rejected > max_response_len:
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
][:max_response_len]
|
][:max_response_len]
|
||||||
sample["rejected"] = tokenizer.decode(
|
sample["rejected"] = tokenizer.decode(
|
||||||
rejected_tokens, skip_special_tokens=True
|
rejected_tokens, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
|
|
||||||
result = sample
|
result = sample
|
||||||
|
else: # handling == "drop"
|
||||||
|
result = (len_prompt + len_chosen) <= sequence_len and (
|
||||||
|
len_prompt + len_rejected
|
||||||
|
) <= sequence_len
|
||||||
|
|
||||||
elif rl == "kto":
|
elif rl == "kto":
|
||||||
if not (sample.get("prompt") and sample.get("completion")):
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
@@ -154,36 +172,36 @@ def drop_long_rl_seq(
|
|||||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if handling == "drop":
|
# Truncate first
|
||||||
result = (len_prompt + len_completion) <= sequence_len
|
if handling == "truncate":
|
||||||
|
|
||||||
# truncate
|
|
||||||
else:
|
|
||||||
# If sequence fits, return sample unchanged
|
# If sequence fits, return sample unchanged
|
||||||
if (len_prompt + len_completion) <= sequence_len:
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
result = sample
|
result = sample
|
||||||
else:
|
else:
|
||||||
# Calculate maximum completion length that can fit with the prompt
|
# Calculate maximum completion length
|
||||||
max_completion_len = sequence_len - len_prompt
|
max_completion_len = sequence_len - len_prompt
|
||||||
|
|
||||||
if max_completion_len <= 0:
|
if max_completion_len <= 0:
|
||||||
# Prompt is already too long, we can't truncate effectively
|
# Prompt too long, return sample for map
|
||||||
result = False if handling == "drop" else sample
|
result = sample
|
||||||
else:
|
else:
|
||||||
# Truncate the completion if needed
|
# Truncate the completion if needed
|
||||||
if len_completion > max_completion_len:
|
if len_completion > max_completion_len:
|
||||||
# Tokenize, truncate, and decode
|
|
||||||
completion_tokens = tokenizer(
|
completion_tokens = tokenizer(
|
||||||
completion, add_special_tokens=False
|
completion, add_special_tokens=False
|
||||||
)["input_ids"][:max_completion_len]
|
)["input_ids"][:max_completion_len]
|
||||||
sample["completion"] = tokenizer.decode(
|
sample["completion"] = tokenizer.decode(
|
||||||
completion_tokens, skip_special_tokens=True
|
completion_tokens, skip_special_tokens=True
|
||||||
)
|
)
|
||||||
|
|
||||||
result = sample
|
result = sample
|
||||||
|
else: # handling == "drop"
|
||||||
|
result = (len_prompt + len_completion) <= sequence_len
|
||||||
|
|
||||||
elif rl == "grpo":
|
elif rl == "grpo":
|
||||||
result = True if handling == "drop" else sample
|
# GRPO doesn't involve sequence length checks in the same way?
|
||||||
|
# The original code returned True for drop. What should it return for truncate?
|
||||||
|
# Let's assume for now it always passes.
|
||||||
|
result = sample if handling == "truncate" else True
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
@@ -234,21 +252,34 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
split_datasets[i] = data_set
|
split_datasets[i] = data_set
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
|
# Determine handling mode
|
||||||
|
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||||
|
|
||||||
drop_long = partial(
|
drop_long = partial(
|
||||||
drop_long_rl_seq,
|
drop_long_rl_seq,
|
||||||
rl=_cfg.rl,
|
rl=_cfg.rl,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
sequence_len=cfg.sequence_len,
|
sequence_len=cfg.sequence_len,
|
||||||
handling=cfg.get("excess_token_handling", "drop"),
|
handling=handling, # Pass the handling mode
|
||||||
)
|
)
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
|
|
||||||
# Use filter for drop mode and map for truncate mode
|
# Use map for truncate mode and filter for drop mode
|
||||||
handling = cfg.get("excess_token_handling", "drop")
|
if handling == "truncate":
|
||||||
if handling == "drop":
|
split_datasets[i] = split_datasets[i].map(
|
||||||
|
drop_long, # Function now returns modified sample or original
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Truncating Long Sequences",
|
||||||
|
)
|
||||||
|
# Note: Length might not change if truncation always occurs
|
||||||
|
LOG.info(
|
||||||
|
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
|
||||||
|
)
|
||||||
|
else: # handling == "drop"
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
drop_long,
|
drop_long, # Function now returns boolean
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
@@ -258,16 +289,6 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Dropped {dropped} long samples from dataset index {i}"
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
split_datasets[i] = split_datasets[i].map(
|
|
||||||
drop_long,
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Truncating Long Sequences",
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
|
|
||||||
)
|
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq
|
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -166,23 +166,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
# Get the handling method from config, default to "drop" for backward compatibility
|
# Get the handling method from config, default to "drop" for backward compatibility
|
||||||
handling = cfg.get("excess_token_handling", "drop")
|
handling = cfg.get("sequence_len_overflow_handling", "drop")
|
||||||
|
|
||||||
if handling == "drop":
|
# Use the new function with the specified handling mode
|
||||||
# Use the existing drop_long_seq function for backward compatibility
|
seq_handler = functools.partial(
|
||||||
seq_handler = functools.partial(
|
truncate_or_drop_long_seq,
|
||||||
drop_long_seq,
|
sequence_len=cfg.sequence_len,
|
||||||
sequence_len=cfg.sequence_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
handling=handling,
|
||||||
)
|
)
|
||||||
else: # handling == "truncate"
|
|
||||||
# Use the new function with truncate mode
|
|
||||||
seq_handler = functools.partial(
|
|
||||||
truncate_or_drop_long_seq,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len,
|
|
||||||
handling=handling,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||||
@@ -206,12 +198,21 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
if handling == "drop":
|
if handling == "truncate":
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
|
||||||
else:
|
|
||||||
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
||||||
|
else: # handling == "drop"
|
||||||
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
|
|
||||||
if handling == "drop":
|
if handling == "truncate":
|
||||||
|
# Use map for truncate mode
|
||||||
|
dataset = dataset.map(
|
||||||
|
seq_handler,
|
||||||
|
batched=True,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
|
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
||||||
|
else: # handling == "drop"
|
||||||
# Use filter for drop mode
|
# Use filter for drop mode
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
seq_handler,
|
seq_handler,
|
||||||
@@ -223,14 +224,5 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
dropped = prior_len - len(dataset)
|
dropped = prior_len - len(dataset)
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||||
else:
|
|
||||||
# Use map for truncate mode
|
|
||||||
dataset = dataset.map(
|
|
||||||
seq_handler,
|
|
||||||
batched=True,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -186,10 +186,10 @@ class AxolotlInputConfig(
|
|||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
excess_token_handling: Literal["drop", "truncate"] = Field(
|
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
|
||||||
default="drop",
|
default="drop",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"
|
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
|
|||||||
@@ -484,22 +484,24 @@ def process_pretraining_datasets_for_packing(
|
|||||||
drop_attention_mask=False,
|
drop_attention_mask=False,
|
||||||
handling="drop",
|
handling="drop",
|
||||||
):
|
):
|
||||||
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
|
# Define the function to use for handling sequences based on the mode
|
||||||
|
seq_handler_fn = partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=sequence_len,
|
||||||
|
handling=handling, # Pass handling mode
|
||||||
|
)
|
||||||
|
|
||||||
# Use filter for drop mode and map for truncate mode
|
# Use map for truncate mode and filter for drop mode
|
||||||
if handling == "drop":
|
if handling == "truncate":
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.map(
|
||||||
drop_long_fn,
|
seq_handler_fn,
|
||||||
desc="Dropping Long Sequences",
|
desc="Truncating Long Sequences",
|
||||||
load_from_cache_file=False,
|
load_from_cache_file=False,
|
||||||
)
|
)
|
||||||
else:
|
else: # handling == "drop"
|
||||||
truncate_fn = partial(
|
train_dataset = train_dataset.filter(
|
||||||
truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling
|
seq_handler_fn, # Use the same function, it returns boolean for drop mode
|
||||||
)
|
desc="Dropping Long Sequences",
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
truncate_fn,
|
|
||||||
desc="Truncating Long Sequences",
|
|
||||||
load_from_cache_file=False,
|
load_from_cache_file=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
155
tests/test_trainer_utils.py
Normal file
155
tests/test_trainer_utils.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
import unittest
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Assuming the function is in axolotl.utils.trainer
|
||||||
|
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for truncate_or_drop_long_seq
|
||||||
|
class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test suite for truncate_or_drop_long_seq function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Example sequence length settings
|
||||||
|
self.sequence_len = 10
|
||||||
|
self.min_sequence_len = 3
|
||||||
|
|
||||||
|
def test_drop_mode_single(self):
|
||||||
|
"""Test drop mode with single examples."""
|
||||||
|
handler = partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=self.sequence_len,
|
||||||
|
min_sequence_len=self.min_sequence_len,
|
||||||
|
handling="drop",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Too short
|
||||||
|
sample_short = {"input_ids": [1, 2]}
|
||||||
|
self.assertFalse(handler(sample_short))
|
||||||
|
|
||||||
|
# Too long
|
||||||
|
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
|
||||||
|
self.assertFalse(handler(sample_long))
|
||||||
|
|
||||||
|
# Just right
|
||||||
|
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
|
||||||
|
self.assertTrue(handler(sample_ok))
|
||||||
|
|
||||||
|
# Empty
|
||||||
|
sample_empty = {"input_ids": []}
|
||||||
|
self.assertFalse(handler(sample_empty))
|
||||||
|
|
||||||
|
def test_truncate_mode_single(self):
|
||||||
|
"""Test truncate mode with single examples."""
|
||||||
|
handler = partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=self.sequence_len,
|
||||||
|
min_sequence_len=self.min_sequence_len,
|
||||||
|
handling="truncate",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Too short (should still be dropped implicitly by filter/map logic upstream,
|
||||||
|
# but the function itself might return the sample or False based on impl.)
|
||||||
|
# Current impl returns the original sample for map if too short, assuming upstream filters.
|
||||||
|
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
|
||||||
|
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
|
||||||
|
result_short = handler(sample_short)
|
||||||
|
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
|
||||||
|
|
||||||
|
# Too long
|
||||||
|
original_long = list(range(self.sequence_len + 5))
|
||||||
|
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
|
||||||
|
result_long = handler(sample_long)
|
||||||
|
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
|
||||||
|
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
|
||||||
|
self.assertEqual(len(result_long["labels"]), self.sequence_len)
|
||||||
|
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
|
||||||
|
|
||||||
|
|
||||||
|
# Just right
|
||||||
|
sample_ok = {"input_ids": list(range(self.min_sequence_len)), "labels": list(range(self.min_sequence_len))}
|
||||||
|
result_ok = handler(sample_ok)
|
||||||
|
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
|
||||||
|
self.assertEqual(result_ok, sample_ok) # Should be unchanged
|
||||||
|
|
||||||
|
# Empty
|
||||||
|
sample_empty = {"input_ids": [], "labels": []}
|
||||||
|
result_empty = handler(sample_empty)
|
||||||
|
self.assertEqual(result_empty, sample_empty) # Unchanged
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_mode_batched(self):
|
||||||
|
"""Test drop mode with batched examples."""
|
||||||
|
handler = partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=self.sequence_len,
|
||||||
|
min_sequence_len=self.min_sequence_len,
|
||||||
|
handling="drop",
|
||||||
|
)
|
||||||
|
sample = {
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2], # Too short
|
||||||
|
list(range(self.sequence_len + 1)), # Too long
|
||||||
|
list(range(self.sequence_len)), # OK (len = 10)
|
||||||
|
list(range(self.min_sequence_len)), # OK (len = 3)
|
||||||
|
[], # Empty
|
||||||
|
]
|
||||||
|
}
|
||||||
|
expected = [False, False, True, True, False]
|
||||||
|
self.assertEqual(handler(sample), expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_mode_batched(self):
|
||||||
|
"""Test truncate mode with batched examples."""
|
||||||
|
handler = partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=self.sequence_len,
|
||||||
|
min_sequence_len=self.min_sequence_len,
|
||||||
|
handling="truncate",
|
||||||
|
)
|
||||||
|
sample = {
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2], # Too short
|
||||||
|
list(range(self.sequence_len + 5)), # Too long
|
||||||
|
list(range(self.sequence_len)), # OK
|
||||||
|
list(range(self.min_sequence_len)), # OK
|
||||||
|
[], # Empty
|
||||||
|
],
|
||||||
|
"labels": [ # Add labels to test truncation
|
||||||
|
[1, 2],
|
||||||
|
list(range(self.sequence_len + 5)),
|
||||||
|
list(range(self.sequence_len)),
|
||||||
|
list(range(self.min_sequence_len)),
|
||||||
|
[],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = handler(sample)
|
||||||
|
|
||||||
|
# Expected results after truncation (too short and empty remain unchanged by this function)
|
||||||
|
expected_input_ids = [
|
||||||
|
[1, 2], # Unchanged (too short)
|
||||||
|
list(range(self.sequence_len)), # Truncated
|
||||||
|
list(range(self.sequence_len)), # Unchanged (OK)
|
||||||
|
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||||
|
[], # Unchanged (Empty)
|
||||||
|
]
|
||||||
|
expected_labels = [
|
||||||
|
[1, 2], # Unchanged (too short)
|
||||||
|
list(range(self.sequence_len)), # Truncated
|
||||||
|
list(range(self.sequence_len)), # Unchanged (OK)
|
||||||
|
list(range(self.min_sequence_len)), # Unchanged (OK)
|
||||||
|
[], # Unchanged (Empty)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
self.assertEqual(result["input_ids"], expected_input_ids)
|
||||||
|
self.assertEqual(result["labels"], expected_labels)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user