Compare commits

21 commits: online-top...775-option
| SHA1 |
|---|
| 0f2d196476 |
| f1a8474400 |
| dc5887c652 |
| 54b542d312 |
| 30a89b07b9 |
| 746c03b097 |
| 47b3fe8af3 |
| f5a3e3529e |
| 618b008e36 |
| 5d7a61576d |
| 5ecf22b54e |
| 9c5b8da22f |
| fea6649518 |
| 124ad2b968 |
| 767c2340f1 |
| f6623c34cc |
| 5dd8f0b2b8 |
| be3c6bbd85 |
| f07db4f853 |
| 17a5838d38 |
| 9f68918f13 |
```diff
@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
                             p.data_ptr(): p.numel() for p in module.parameters()
                         }.values()
                     )
-                    LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                    LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
                     manager.register_module_override(
                         module, "weight", {"optim_bits": 32}
                    )
                    LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
-            LOG.info(f"skipped: {skipped/2**20}M params")
+            LOG.info(f"skipped: {skipped / 2 ** 20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
```
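The two reformatted log lines compute an identical "M params" figure; only the operator spacing changes. A quick standalone sanity check of the arithmetic (not part of the diff):

```python
# 2 ** 20 = 1,048,576, so dividing a parameter count by it yields mebi-units;
# the spacing-only change above does not alter the result.
skipped = 3 * 2 ** 20
assert f"{skipped / 2 ** 20}M params" == f"{skipped/2**20}M params" == "3.0M params"
```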
@@ -10,6 +10,7 @@ from torch.utils.data import RandomSampler
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||
@@ -259,6 +260,15 @@ def encode_packed_pretraining(
|
||||
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
|
||||
# workaround by using the position id logic for now in trainer
|
||||
drop_attention_mask=multipack_attn,
|
||||
# pass through handling mode from config via ds_wrapper function
|
||||
handling=(
|
||||
getattr(ds_wrapper, "cfg", {}).get(
|
||||
"sequence_len_overflow_handling",
|
||||
getattr(ds_wrapper, "cfg", {}).get(
|
||||
"excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
|
||||
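The two-key lookup above is a fallback chain: the new `sequence_len_overflow_handling` key wins, the legacy `excess_token_handling` alias is honored next, and `"drop"` is the final default. A standalone sketch of that chain (a plain dict stands in for the config object; names mirror the diff):

```python
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"

def resolve_handling(cfg: dict) -> str:
    # New key first, then the legacy alias, then the module-level default.
    return cfg.get(
        "sequence_len_overflow_handling",
        cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
    )

assert resolve_handling({}) == "drop"
assert resolve_handling({"excess_token_handling": "truncate"}) == "truncate"
assert resolve_handling({"sequence_len_overflow_handling": "drop",
                         "excess_token_handling": "truncate"}) == "drop"
```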
```diff
@@ -122,6 +122,14 @@ def _map_dataset(
     return dataset
 
 
+def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
+    """
+    Backward-compatibility wrapper for legacy imports in tests.
+    Delegates to the new predicate, forwarding the handling mode.
+    """
+    return _drop_long_sequences(sample, rl, tokenizer, sequence_len, handling)
+
+
 def _drop_long_sequences(
-    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
-) -> bool:
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int, handling: str = "drop"
+) -> bool | dict[str, Any]:
@@ -155,11 +163,51 @@ def _drop_long_sequences(
         len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
         len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
 
-        return (len_prompt + len_chosen) <= sequence_len and (
-            len_prompt + len_rejected
-        ) <= sequence_len
+        # Truncate first, then drop if still invalid (although truncate should handle it)
+        handling_mode = sample.get("sequence_len_overflow_handling", handling)
+        if handling_mode == "truncate":
+            # If both sequences fit, return sample unchanged
+            if (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len:
+                result = sample
+            else:
+                # Calculate maximum response length that can fit with the prompt
+                max_response_len = sequence_len - len_prompt
 
-    if rl is RLType.KTO:
+                if max_response_len <= 0:
+                    # Prompt itself exceeds sequence length. Cannot truncate the
+                    # responses to fix it; return the sample unchanged so the
+                    # post-truncation filter can drop it.
+                    LOG.warning(
+                        "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; will be dropped after truncation",
+                        len_prompt,
+                        sequence_len,
+                    )
+                    result = sample
+
+                else:
+                    # Truncate the chosen and rejected responses if needed
+                    if len_chosen > max_response_len:
+                        chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
+                            "input_ids"
+                        ][:max_response_len]
+                        sample["chosen"] = tokenizer.decode(
+                            chosen_tokens, skip_special_tokens=True
+                        )
+
+                    if len_rejected > max_response_len:
+                        rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
+                            "input_ids"
+                        ][:max_response_len]
+                        sample["rejected"] = tokenizer.decode(
+                            rejected_tokens, skip_special_tokens=True
+                        )
+                    result = sample
+        else:  # handling == "drop"
+            result = (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len
+
+    elif rl == RLType.KTO:
         if not (sample.get("prompt") and sample.get("completion")):
             raise ValueError("Prompt and completion keys are required for KTO datasets")
 
@@ -171,12 +219,86 @@ def _drop_long_sequences(
             tokenizer(completion, add_special_tokens=False)["input_ids"]
         )
 
-        return (len_prompt + len_completion) <= sequence_len
+        # Truncate first
+        handling_mode = sample.get("sequence_len_overflow_handling", handling)
+        if handling_mode == "truncate":
+            # If sequence fits, return sample unchanged
+            if (len_prompt + len_completion) <= sequence_len:
+                result = sample
+            else:
+                # Calculate maximum completion length
+                max_completion_len = sequence_len - len_prompt
 
-    if rl is RLType.GRPO:
-        return True
+                if max_completion_len <= 0:
+                    # Prompt itself exceeds sequence length. Cannot truncate the
+                    # completion to fix it; return the sample unchanged so the
+                    # post-truncation filter can drop it.
+                    LOG.warning(
+                        "Prompt length (%s) exceeds sequence length (%s) for KTO sample; will be dropped after truncation",
+                        len_prompt,
+                        sequence_len,
+                    )
+                    result = sample
+                else:
+                    # Truncate the completion if needed
+                    if len_completion > max_completion_len:
+                        completion_tokens = tokenizer(
+                            completion, add_special_tokens=False
+                        )["input_ids"][:max_completion_len]
+                        sample["completion"] = tokenizer.decode(
+                            completion_tokens, skip_special_tokens=True
+                        )
+                    result = sample
+        else:  # handling == "drop"
+            result = (len_prompt + len_completion) <= sequence_len
 
-    raise ValueError("Unknown RL type")
+    elif rl == RLType.GRPO:
+        # For GRPO always keep; in truncate mode hand the sample back unchanged
+        result = sample if handling == "truncate" else True
+    else:
+        raise ValueError("Unknown RL type")
+
+    return result
+
+
-def load_prepare_preference_datasets(cfg):
+def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
+    """
+    Boolean predicate to check whether a preference-learning sample fits within sequence_len.
+    Used with dataset.filter() after truncation to drop unsalvageable samples.
+    """
+    if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
+        if not (
+            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
+        ):
+            return False
+        prompt = sample["prompt"]
+        chosen = sample["chosen"]
+        rejected = sample["rejected"]
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
+        len_rejected = len(
+            tokenizer(rejected, add_special_tokens=False)["input_ids"]
+        )
+        return (len_prompt + len_chosen) <= sequence_len and (
+            len_prompt + len_rejected
+        ) <= sequence_len
+    if rl == RLType.KTO:
+        if not (sample.get("prompt") and sample.get("completion")):
+            return False
+        prompt = sample["prompt"]
+        completion = sample["completion"]
+        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+        len_completion = len(
+            tokenizer(completion, add_special_tokens=False)["input_ids"]
+        )
+        return (len_prompt + len_completion) <= sequence_len
+    if rl == RLType.GRPO:
+        # GRPO does not enforce this check here
+        return True
+    return False
+
+
+# Legacy shim preserved for backward compatibility; no-op in new flow
+def load_split(dataset_cfgs, _cfg):  # noqa: F811
+    return None
+
+
+def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
```
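As a quick illustration of the truncation rule above, here is a self-contained toy run with a fake character-level tokenizer (one token per character, mimicking the MagicMock tokenizer in the tests below); the helper names are stand-ins for this sketch, not the axolotl API:

```python
def fake_tokenize(text):
    # One token per character.
    return {"input_ids": list(text)}

def fake_decode(tokens):
    return "".join(tokens)

sequence_len = 20
sample = {"prompt": "p" * 5, "chosen": "c" * 18, "rejected": "r" * 10}
len_prompt = len(fake_tokenize(sample["prompt"])["input_ids"])
max_response_len = sequence_len - len_prompt  # 15 tokens left for each response
for key in ("chosen", "rejected"):
    ids = fake_tokenize(sample[key])["input_ids"]
    if len_prompt + len(ids) > sequence_len:
        sample[key] = fake_decode(ids[:max_response_len])
assert len(sample["chosen"]) == 15   # truncated
assert len(sample["rejected"]) == 10  # already fit; unchanged
```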
```diff
@@ -15,10 +15,12 @@ from datasets import Dataset, IterableDataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers.utils import get_dataset_lengths
-from axolotl.utils.trainer import drop_long_seq
+from axolotl.utils.trainer import truncate_or_drop_long_seq
 
 LOG = get_logger(__name__)
 
+DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
+
 
 class RetryStrategy(Enum):
     """Enum for retry strategies."""
@@ -168,10 +170,19 @@ def drop_long_seq_in_dataset(
         )
         return dataset
 
-    drop_long = functools.partial(
-        drop_long_seq,
+    # Get the handling method from config, default to "drop" for backward compatibility.
+    # Support legacy alias "excess_token_handling" as well.
+    handling = cfg.get(
+        "sequence_len_overflow_handling",
+        cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
+    )
+
+    # Use the function with the specified handling mode
+    seq_handler = functools.partial(
+        truncate_or_drop_long_seq,
         sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
+        handling=handling,
     )
 
     with contextlib.suppress(AttributeError):
@@ -190,17 +201,31 @@ def drop_long_seq_in_dataset(
 
     drop_long_kwargs = {}
     if filter_map_kwargs:
-        drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
+        if handling == "truncate":
+            drop_long_kwargs["desc"] = "Truncating Long Sequences"
+        else:  # handling == "drop"
+            drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
-    dataset = dataset.filter(
-        drop_long,
-        batched=True,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
-    if prior_len:
-        dropped = prior_len - len(dataset)
-        if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+    if handling == "truncate":
+        # Use map for truncate mode
+        dataset = dataset.map(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        LOG.info(f"Truncated long samples in dataset to {sequence_len} tokens")
+    else:  # handling == "drop"
+        # Use filter for drop mode
+        dataset = dataset.filter(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        if prior_len:
+            dropped = prior_len - len(dataset)
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from dataset")
 
     return dataset
```
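The core of the change is the dispatch above: truncate mode rewrites rows with `Dataset.map`, while drop mode removes rows with `Dataset.filter`. A minimal standalone sketch of that dispatch on a toy Hugging Face dataset (the lambda handlers are simplified stand-ins for `truncate_or_drop_long_seq`):

```python
from datasets import Dataset

sequence_len = 4
ds = Dataset.from_dict({"input_ids": [[1, 2], [1, 2, 3, 4, 5, 6]]})

handling = "truncate"  # or "drop"
if handling == "truncate":
    # map keeps both rows, cutting the long one down to sequence_len
    ds = ds.map(lambda s: {"input_ids": s["input_ids"][:sequence_len]})
else:
    # filter keeps only rows that already fit
    ds = ds.filter(lambda s: len(s["input_ids"]) <= sequence_len)

assert ds["input_ids"] == [[1, 2], [1, 2, 3, 4]]
```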
```diff
@@ -414,6 +414,12 @@ class AxolotlInputConfig(
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
+        default="drop",
+        json_schema_extra={
+            "description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
+        },
+    )
     eval_sequence_len: int | None = Field(
         default=None,
         json_schema_extra={
```
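The new field is a standard Pydantic `Literal` knob, so invalid values fail at config-validation time. A minimal standalone sketch of the same pattern (a toy model, not the real `AxolotlInputConfig`):

```python
from typing import Literal

from pydantic import BaseModel, Field, ValidationError

class ToyConfig(BaseModel):
    sequence_len: int = 2048
    sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(default="drop")

assert ToyConfig().sequence_len_overflow_handling == "drop"
assert ToyConfig(sequence_len_overflow_handling="truncate").sequence_len_overflow_handling == "truncate"
try:
    ToyConfig(sequence_len_overflow_handling="pad")  # not an allowed literal
except ValidationError:
    pass  # rejected at validation time, before training starts
```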
```diff
@@ -233,6 +233,114 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     return results
 
 
+def truncate_or_drop_long_seq(
+    sample, sequence_len=2048, min_sequence_len=2, handling="drop"
+):
+    """
+    Either drop or truncate samples whose sequence length is too long (> sequence_len)
+    or too short (< min_sequence_len).
+
+    If handling is "drop":
+        - Samples that are too short or too long will be dropped
+    If handling is "truncate":
+        - Samples that are too short are returned unchanged; dropping them is left
+          to a separate filter
+        - Samples that are too long will be truncated to sequence_len
+
+    Works for both single examples (list[int]) and batches (list[list[int]]).
+    Returns either a boolean/list of booleans (for drop mode) or the modified sample
+    (for truncate mode).
+    """
+    min_sequence_len = min_sequence_len or 2
+    result = None
+
+    if handling == "drop":
+        return drop_long_seq(sample, sequence_len, min_sequence_len)
+
+    input_ids = sample["input_ids"]
+
+    # Edge case: if input_ids is empty
+    if not input_ids:
+        result = False if handling == "drop" else sample
+    # Single example (input_ids is a list of int)
+    elif isinstance(input_ids[0], int):
+        length = len(input_ids)
+
+        # Handle samples that are too short - always drop them
+        if length < min_sequence_len:
+            result = False if handling == "drop" else sample
+        # If truncation is enabled and the sample is too long, truncate it
+        elif length > sequence_len and handling == "truncate":
+            sample["input_ids"] = input_ids[:sequence_len]
+
+            # Also truncate attention_mask if present
+            if "attention_mask" in sample:
+                sample["attention_mask"] = sample["attention_mask"][:sequence_len]
+
+            # Also truncate labels if present
+            if "labels" in sample:
+                sample["labels"] = sample["labels"][:sequence_len]
+
+            # Also truncate position_ids if present
+            if "position_ids" in sample:
+                sample["position_ids"] = sample["position_ids"][:sequence_len]
+
+            # Update length if present
+            if "length" in sample:
+                sample["length"] = sequence_len
+
+            result = sample
+        # For drop mode or if the sample doesn't exceed max length
+        else:
+            result = (
+                min_sequence_len <= length <= sequence_len
+                if handling == "drop"
+                else sample
+            )
+    # Batched (input_ids is a list of lists)
+    else:
+        if handling == "drop":
+            results = []
+            for seq in input_ids:
+                length = len(seq)
+                results.append(min_sequence_len <= length <= sequence_len)
+            result = results
+        else:  # truncate
+            # Check each sequence in the batch
+            for i, seq in enumerate(input_ids):
+                length = len(seq)
+
+                # Skip sequences that are too short
+                if length < min_sequence_len:
+                    continue
+
+                # Truncate sequences that are too long
+                if length > sequence_len:
+                    input_ids[i] = seq[:sequence_len]
+
+                    # Also truncate attention_mask if present
+                    if "attention_mask" in sample:
+                        sample["attention_mask"][i] = sample["attention_mask"][i][
+                            :sequence_len
+                        ]
+
+                    # Also truncate labels if present
+                    if "labels" in sample:
+                        sample["labels"][i] = sample["labels"][i][:sequence_len]
+
+                    # Also truncate position_ids if present
+                    if "position_ids" in sample:
+                        sample["position_ids"][i] = sample["position_ids"][i][
+                            :sequence_len
+                        ]
+
+                    # Update length if present
+                    if "length" in sample:
+                        sample["length"][i] = sequence_len
+
+            result = sample
+
+    return result
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
     if drop_attn_mask:
```
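A quick usage sketch of the handler above in truncate mode, applied to a toy sample (assumes the diff is applied so the function is importable; the sample values are made up):

```python
from functools import partial

from axolotl.utils.trainer import truncate_or_drop_long_seq

truncate = partial(truncate_or_drop_long_seq, sequence_len=4, handling="truncate")

sample = {"input_ids": list(range(6)), "labels": list(range(6)), "length": 6}
out = truncate(sample)
# input_ids and labels (and attention_mask/position_ids when present) are cut
# together so the columns stay aligned; "length" is updated to match.
assert out["input_ids"] == [0, 1, 2, 3]
assert out["labels"] == [0, 1, 2, 3]
assert out["length"] == 4
```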
```diff
@@ -368,15 +476,33 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
 
 
 def process_pretraining_datasets_for_packing(
-    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
+    train_dataset,
+    sequence_len,
+    skip_position_ids=True,
+    drop_attention_mask=False,
+    handling="drop",
 ):
-    drop_long = partial(drop_long_seq, sequence_len=sequence_len)
-
-    train_dataset = train_dataset.filter(
-        drop_long,
-        desc="Dropping Long Sequences",
-        load_from_cache_file=False,
+    # Define the function to use for handling sequences based on the mode
+    seq_handler_fn = partial(
+        truncate_or_drop_long_seq,
+        sequence_len=sequence_len,
+        handling=handling,  # Pass handling mode
     )
+
+    # Use map for truncate mode and filter for drop mode
+    if handling == "truncate":
+        train_dataset = train_dataset.map(
+            seq_handler_fn,
+            desc="Truncating Long Sequences",
+            load_from_cache_file=False,
+        )
+    else:  # handling == "drop"
+        train_dataset = train_dataset.filter(
+            seq_handler_fn,  # Use the same function, it returns boolean for drop mode
+            desc="Dropping Long Sequences",
+            load_from_cache_file=False,
+        )
 
     if not skip_position_ids:
         train_dataset = train_dataset.map(
             add_position_ids,
```
```diff
@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
 """
 
 import unittest
+from unittest.mock import MagicMock
 
 from transformers import LlamaTokenizer
 
 from axolotl.utils.data import encode_pretraining, md5
+from axolotl.utils.data.rl import drop_long_rl_seq
 
 from tests.hf_offline_utils import enable_hf_offline
```
```diff
@@ -64,5 +66,254 @@ class TestEncodePretraining(unittest.TestCase):
         )
 
 
+class TestDropLongRLSeq(unittest.TestCase):
+    """
+    Tests for the drop_long_rl_seq function.
+    """
+
+    def setUp(self):
+        # Mock tokenizer that returns length based on input string length
+        self.tokenizer = MagicMock()
+
+        def side_effect_func(
+            text, add_special_tokens=False
+        ):  # pylint: disable=unused-argument
+            return {"input_ids": list(range(len(text)))}
+
+        self.tokenizer.side_effect = side_effect_func
+        self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
+            ["x"] * len(tokens)
+        )  # pylint: disable=unused-argument
+
+        self.sequence_len = 20
+
+    def test_dpo_drop_mode_valid(self):
+        """Test DPO drop mode with a valid sample."""
+        sample = {
+            "prompt": "p" * 5,
+            "chosen": "c" * 7,
+            "rejected": "r" * 6,
+        }  # 5+7=12 <= 20, 5+6=11 <= 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertTrue(result)
+
+    def test_dpo_drop_mode_invalid_chosen(self):
+        """Test DPO drop mode with chosen too long."""
+        sample = {
+            "prompt": "p" * 5,
+            "chosen": "c" * 16,
+            "rejected": "r" * 6,
+        }  # 5+16=21 > 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertFalse(result)
+
+    def test_dpo_drop_mode_invalid_rejected(self):
+        """Test DPO drop mode with rejected too long."""
+        sample = {
+            "prompt": "p" * 5,
+            "chosen": "c" * 7,
+            "rejected": "r" * 16,
+        }  # 5+16=21 > 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertFalse(result)
+
+    def test_dpo_truncate_mode_no_truncation_needed(self):
+        """Test DPO truncate mode when no truncation is needed."""
+        sample = {
+            "prompt": "p" * 5,
+            "chosen": "c" * 7,
+            "rejected": "r" * 6,
+        }  # 5+7=12 <= 20, 5+6=11 <= 20
+        original_sample = sample.copy()
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(
+            result, original_sample
+        )  # Should return the original sample unchanged
+
+    def test_dpo_truncate_mode_prompt_too_long(self):
+        """Test DPO truncate mode when the prompt itself is too long."""
+        sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
+        original_sample = sample.copy()
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        # Even though truncation isn't possible, the function should return the original
+        # sample for the map operation, assuming downstream filtering will catch it.
+        self.assertEqual(result, original_sample)
+
+    def test_dpo_truncate_mode_chosen_truncated(self):
+        """Test DPO truncate mode when only 'chosen' needs truncation."""
+        prompt_len = 5
+        max_resp_len = self.sequence_len - prompt_len  # 20 - 5 = 15
+        sample = {
+            "prompt": "p" * prompt_len,
+            "chosen": "c" * 18,
+            "rejected": "r" * 10,
+        }  # 5+18=23 > 20, 5+10=15 <= 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(len(result["prompt"]), prompt_len)
+        self.assertEqual(len(result["chosen"]), max_resp_len)  # Truncated to 15
+        self.assertEqual(
+            result["chosen"], "x" * max_resp_len
+        )  # Check decoded truncated value
+        self.assertEqual(len(result["rejected"]), 10)  # Unchanged
+
+    def test_dpo_truncate_mode_rejected_truncated(self):
+        """Test DPO truncate mode when only 'rejected' needs truncation."""
+        prompt_len = 5
+        max_resp_len = self.sequence_len - prompt_len  # 15
+        sample = {
+            "prompt": "p" * prompt_len,
+            "chosen": "c" * 10,
+            "rejected": "r" * 18,
+        }  # 5+10=15 <= 20, 5+18=23 > 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(len(result["prompt"]), prompt_len)
+        self.assertEqual(len(result["chosen"]), 10)  # Unchanged
+        self.assertEqual(len(result["rejected"]), max_resp_len)  # Truncated to 15
+        self.assertEqual(
+            result["rejected"], "x" * max_resp_len
+        )  # Check decoded truncated value
+
+    def test_dpo_truncate_mode_both_truncated(self):
+        """Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
+        prompt_len = 8
+        max_resp_len = self.sequence_len - prompt_len  # 20 - 8 = 12
+        sample = {
+            "prompt": "p" * prompt_len,
+            "chosen": "c" * 15,
+            "rejected": "r" * 14,
+        }  # 8+15=23 > 20, 8+14=22 > 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(len(result["prompt"]), prompt_len)
+        self.assertEqual(len(result["chosen"]), max_resp_len)  # Truncated to 12
+        self.assertEqual(result["chosen"], "x" * max_resp_len)
+        self.assertEqual(len(result["rejected"]), max_resp_len)  # Truncated to 12
+        self.assertEqual(result["rejected"], "x" * max_resp_len)
+
+    def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
+        """Test DPO truncate mode where one response overflows and only that one is truncated."""
+        # "chosen" is slightly too long, causing the initial length check to fail,
+        # while "rejected" already fits within max_response_len, so only "chosen"
+        # gets truncated.
+        prompt_len = 10
+        max_resp_len = self.sequence_len - prompt_len  # 10
+        sample = {
+            "prompt": "p" * prompt_len,
+            "chosen": "c" * 11,
+            "rejected": "r" * 9,
+        }  # 10+11=21 > 20, 10+9=19 <= 20
+        result = drop_long_rl_seq(
+            sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(len(result["prompt"]), prompt_len)
+        self.assertEqual(len(result["chosen"]), max_resp_len)  # Truncated to 10
+        self.assertEqual(result["chosen"], "x" * max_resp_len)
+        self.assertEqual(len(result["rejected"]), 9)  # Unchanged, as 9 <= 10
+
+    def test_kto_drop_mode_valid(self):
+        """Test KTO drop mode with a valid sample."""
+        sample = {"prompt": "p" * 5, "completion": "c" * 14}  # 5+14=19 <= 20
+        result = drop_long_rl_seq(
+            sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertTrue(result)
+
+    def test_kto_drop_mode_invalid(self):
+        """Test KTO drop mode with an invalid sample."""
+        sample = {"prompt": "p" * 5, "completion": "c" * 16}  # 5+16=21 > 20
+        result = drop_long_rl_seq(
+            sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertFalse(result)
+
+    def test_kto_truncate_mode_no_truncation_needed(self):
+        """Test KTO truncate mode when no truncation is needed."""
+        sample = {"prompt": "p" * 5, "completion": "c" * 14}  # 5+14=19 <= 20
+        original_sample = sample.copy()
+        result = drop_long_rl_seq(
+            sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(result, original_sample)
+
+    def test_kto_truncate_mode_prompt_too_long(self):
+        """Test KTO truncate mode when the prompt itself is too long."""
+        sample = {"prompt": "p" * 25, "completion": "c" * 7}
+        original_sample = sample.copy()
+        result = drop_long_rl_seq(
+            sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(result, original_sample)  # Returns original sample
+
+    def test_kto_truncate_mode_completion_truncated(self):
+        """Test KTO truncate mode when completion needs truncation."""
+        prompt_len = 8
+        max_comp_len = self.sequence_len - prompt_len  # 20 - 8 = 12
+        sample = {"prompt": "p" * prompt_len, "completion": "c" * 15}  # 8+15=23 > 20
+        result = drop_long_rl_seq(
+            sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(len(result["prompt"]), prompt_len)
+        self.assertEqual(len(result["completion"]), max_comp_len)  # Truncated to 12
+        self.assertEqual(result["completion"], "x" * max_comp_len)
+
+    def test_missing_keys_dpo(self):
+        """Test ValueError raised if keys missing for DPO."""
+        sample = {"prompt": "p"}
+        with self.assertRaisesRegex(
+            ValueError, "Prompt, chosen and rejected keys are required"
+        ):
+            drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
+
+    def test_missing_keys_kto(self):
+        """Test ValueError raised if keys missing for KTO."""
+        sample = {"prompt": "p"}
+        with self.assertRaisesRegex(
+            ValueError, "Prompt and completion keys are required"
+        ):
+            drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
+
+    def test_unknown_rl_type(self):
+        """Test ValueError raised for unknown RL type."""
+        sample = {}
+        with self.assertRaisesRegex(ValueError, "Unknown RL type"):
+            drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
+
+    # GRPO tests - the current implementation always keeps GRPO samples
+    def test_grpo_drop(self):
+        """Test GRPO drop mode (currently always True)."""
+        sample = {}
+        result = drop_long_rl_seq(
+            sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
+        )
+        self.assertTrue(result)
+
+    def test_grpo_truncate(self):
+        """Test GRPO truncate mode (currently returns original sample)."""
+        sample = {"a": 1}
+        result = drop_long_rl_seq(
+            sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
+        )
+        self.assertEqual(result, sample)
+
+
+if __name__ == "__main__":
+    unittest.main()
```
tests/test_trainer_utils.py (new file, +153 lines)

```diff
@@ -0,0 +1,153 @@
+"""Module containing tests for trainer utility functions."""
+
+import unittest
+from functools import partial
+
+from axolotl.utils.trainer import truncate_or_drop_long_seq
+
+
+# Test cases for truncate_or_drop_long_seq
+class TestTruncateOrDropLongSeq(unittest.TestCase):
+    """
+    Test suite for truncate_or_drop_long_seq function.
+    """
+
+    def setUp(self):
+        # Example sequence length settings
+        self.sequence_len = 10
+        self.min_sequence_len = 3
+
+    def test_drop_mode_single(self):
+        """Test drop mode with single examples."""
+        handler = partial(
+            truncate_or_drop_long_seq,
+            sequence_len=self.sequence_len,
+            min_sequence_len=self.min_sequence_len,
+            handling="drop",
+        )
+
+        # Too short
+        sample_short = {"input_ids": [1, 2]}
+        self.assertFalse(handler(sample_short))
+
+        # Too long
+        sample_long = {"input_ids": list(range(self.sequence_len + 1))}
+        self.assertFalse(handler(sample_long))
+
+        # Just right
+        sample_ok = {"input_ids": list(range(self.min_sequence_len))}
+        self.assertTrue(handler(sample_ok))
+
+        # Empty
+        sample_empty = {"input_ids": []}
+        self.assertFalse(handler(sample_empty))
+
+    def test_truncate_mode_single(self):
+        """Test truncate mode with single examples."""
+        handler = partial(
+            truncate_or_drop_long_seq,
+            sequence_len=self.sequence_len,
+            min_sequence_len=self.min_sequence_len,
+            handling="truncate",
+        )
+
+        # Too short: in truncate mode the function returns the sample unchanged;
+        # dropping too-short samples is left to the upstream filter logic.
+        sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
+        result_short = handler(sample_short)
+        self.assertEqual(result_short["input_ids"], [1, 2])  # Unchanged
+
+        # Too long
+        original_long = list(range(self.sequence_len + 5))
+        sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
+        result_long = handler(sample_long)
+        self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
+        self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
+        self.assertEqual(len(result_long["labels"]), self.sequence_len)
+        self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
+
+        # Just right
+        sample_ok = {
+            "input_ids": list(range(self.min_sequence_len)),
+            "labels": list(range(self.min_sequence_len)),
+        }
+        result_ok = handler(sample_ok)
+        self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
+        self.assertEqual(result_ok, sample_ok)  # Should be unchanged
+
+        # Empty
+        sample_empty = {"input_ids": [], "labels": []}
+        result_empty = handler(sample_empty)
+        self.assertEqual(result_empty, sample_empty)  # Unchanged
+
+    def test_drop_mode_batched(self):
+        """Test drop mode with batched examples."""
+        handler = partial(
+            truncate_or_drop_long_seq,
+            sequence_len=self.sequence_len,
+            min_sequence_len=self.min_sequence_len,
+            handling="drop",
+        )
+        sample = {
+            "input_ids": [
+                [1, 2],  # Too short
+                list(range(self.sequence_len + 1)),  # Too long
+                list(range(self.sequence_len)),  # OK (len = 10)
+                list(range(self.min_sequence_len)),  # OK (len = 3)
+                [],  # Empty
+            ]
+        }
+        expected = [False, False, True, True, False]
+        self.assertEqual(handler(sample), expected)
+
+    def test_truncate_mode_batched(self):
+        """Test truncate mode with batched examples."""
+        handler = partial(
+            truncate_or_drop_long_seq,
+            sequence_len=self.sequence_len,
+            min_sequence_len=self.min_sequence_len,
+            handling="truncate",
+        )
+        sample = {
+            "input_ids": [
+                [1, 2],  # Too short
+                list(range(self.sequence_len + 5)),  # Too long
+                list(range(self.sequence_len)),  # OK
+                list(range(self.min_sequence_len)),  # OK
+                [],  # Empty
+            ],
+            "labels": [  # Add labels to test truncation
+                [1, 2],
+                list(range(self.sequence_len + 5)),
+                list(range(self.sequence_len)),
+                list(range(self.min_sequence_len)),
+                [],
+            ],
+        }
+
+        result = handler(sample)
+
+        # Expected results after truncation (too short and empty remain unchanged by this function)
+        expected_input_ids = [
+            [1, 2],  # Unchanged (too short)
+            list(range(self.sequence_len)),  # Truncated
+            list(range(self.sequence_len)),  # Unchanged (OK)
+            list(range(self.min_sequence_len)),  # Unchanged (OK)
+            [],  # Unchanged (Empty)
+        ]
+        expected_labels = [
+            [1, 2],  # Unchanged (too short)
+            list(range(self.sequence_len)),  # Truncated
+            list(range(self.sequence_len)),  # Unchanged (OK)
+            list(range(self.min_sequence_len)),  # Unchanged (OK)
+            [],  # Unchanged (Empty)
+        ]
+
+        self.assertEqual(result["input_ids"], expected_input_ids)
+        self.assertEqual(result["labels"], expected_labels)
+
+
+if __name__ == "__main__":
+    unittest.main()
```