📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length
Docstring generation was requested by @mhenrichsen. * https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776 The following files were modified: * `src/axolotl/utils/data/pretraining.py` * `src/axolotl/utils/data/rl.py` * `src/axolotl/utils/data/utils.py` * `src/axolotl/utils/trainer.py` * `tests/test_data.py` * `tests/test_trainer_utils.py`
This commit is contained in:
committed by
GitHub
parent
5d7a61576d
commit
e23a5c9fda
@@ -60,6 +60,9 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
|
||||
|
||||
def test_md5(self):
|
||||
"""
|
||||
Tests that the md5 function returns the correct hash for a given string and encoding.
|
||||
"""
|
||||
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
|
||||
self.assertEqual(
|
||||
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||
@@ -73,11 +76,26 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock tokenizer that returns length based on input string length
|
||||
"""
|
||||
Sets up a mock tokenizer and sequence length for RL sequence length tests.
|
||||
|
||||
The mock tokenizer simulates tokenization by returning a list of input IDs whose length equals the input string's length, and decodes tokens as a string of repeated "x" characters. The sequence length limit is set to 20.
|
||||
"""
|
||||
self.tokenizer = MagicMock()
|
||||
|
||||
def side_effect_func(
|
||||
text, add_special_tokens=False
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Simulates tokenization by returning input IDs as a list of consecutive integers, one per character of the input text.
|
||||
|
||||
Args:
|
||||
text: The input string to tokenize.
|
||||
add_special_tokens: Ignored parameter included for interface compatibility.
|
||||
|
||||
Returns:
|
||||
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
|
||||
"""
|
||||
return {"input_ids": list(range(len(text)))}
|
||||
|
||||
self.tokenizer.side_effect = side_effect_func
|
||||
@@ -88,7 +106,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.sequence_len = 20
|
||||
|
||||
def test_dpo_drop_mode_valid(self):
|
||||
"""Test DPO drop mode with a valid sample."""
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
@@ -100,7 +120,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_chosen(self):
|
||||
"""Test DPO drop mode with chosen too long."""
|
||||
"""
|
||||
Tests that in DPO drop mode, a sample is dropped (drop_long_rl_seq returns False) when the prompt plus the chosen response exceeds the sequence length limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 16,
|
||||
@@ -112,7 +134,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_drop_mode_invalid_rejected(self):
|
||||
"""Test DPO drop mode with rejected too long."""
|
||||
"""
|
||||
Tests that in DPO drop mode, a sample is dropped (drop_long_rl_seq returns False) when the prompt plus the rejected response exceeds the sequence length limit.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
@@ -124,7 +148,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed(self):
|
||||
"""Test DPO truncate mode when no truncation is needed."""
|
||||
"""
|
||||
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
|
||||
"""
|
||||
sample = {
|
||||
"prompt": "p" * 5,
|
||||
"chosen": "c" * 7,
|
||||
@@ -139,7 +165,10 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
) # Should return the original sample unchanged
|
||||
|
||||
def test_dpo_truncate_mode_prompt_too_long(self):
|
||||
"""Test DPO truncate mode when the prompt itself is too long."""
|
||||
"""
|
||||
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
|
||||
the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
@@ -150,7 +179,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_dpo_truncate_mode_chosen_truncated(self):
|
||||
"""Test DPO truncate mode when only 'chosen' needs truncation."""
|
||||
"""
|
||||
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
|
||||
"""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
|
||||
sample = {
|
||||
@@ -169,7 +200,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(len(result["rejected"]), 10) # Unchanged
|
||||
|
||||
def test_dpo_truncate_mode_rejected_truncated(self):
|
||||
"""Test DPO truncate mode when only 'rejected' needs truncation."""
|
||||
"""
|
||||
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
|
||||
"""
|
||||
prompt_len = 5
|
||||
max_resp_len = self.sequence_len - prompt_len # 15
|
||||
sample = {
|
||||
@@ -188,7 +221,11 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
) # Check decoded truncated value
|
||||
|
||||
def test_dpo_truncate_mode_both_truncated(self):
|
||||
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
|
||||
"""
|
||||
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
|
||||
|
||||
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
|
||||
"""
|
||||
prompt_len = 8
|
||||
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {
|
||||
@@ -206,7 +243,11 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(result["rejected"], "x" * max_resp_len)
|
||||
|
||||
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
|
||||
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
|
||||
"""
|
||||
Tests DPO truncate mode where only the overlong response is truncated.
|
||||
|
||||
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
|
||||
"""
|
||||
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
|
||||
# but the initial check failed because e.g. prompt + chosen > sequence_len
|
||||
# The current logic *will* truncate if len(chosen) > max_resp_len.
|
||||
@@ -230,7 +271,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
# Add similar tests for KTO if needed, checking prompt + completion length
|
||||
|
||||
def test_kto_drop_mode_valid(self):
|
||||
"""Test KTO drop mode with a valid sample."""
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
@@ -238,7 +281,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_kto_drop_mode_invalid(self):
|
||||
"""Test KTO drop mode with an invalid sample."""
|
||||
"""
|
||||
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
|
||||
result = drop_long_rl_seq(
|
||||
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||
@@ -246,7 +291,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_kto_truncate_mode_no_truncation_needed(self):
|
||||
"""Test KTO truncate mode when no truncation is needed."""
|
||||
"""
|
||||
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
|
||||
"""
|
||||
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
@@ -255,7 +302,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(result, original_sample)
|
||||
|
||||
def test_kto_truncate_mode_prompt_too_long(self):
|
||||
"""Test KTO truncate mode when the prompt itself is too long."""
|
||||
"""
|
||||
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"prompt": "p" * 25, "completion": "c" * 7}
|
||||
original_sample = sample.copy()
|
||||
result = drop_long_rl_seq(
|
||||
@@ -264,7 +313,11 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(result, original_sample) # Returns original sample
|
||||
|
||||
def test_kto_truncate_mode_completion_truncated(self):
|
||||
"""Test KTO truncate mode when completion needs truncation."""
|
||||
"""
|
||||
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
|
||||
|
||||
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
|
||||
"""
|
||||
prompt_len = 8
|
||||
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
|
||||
@@ -276,7 +329,11 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertEqual(result["completion"], "x" * max_comp_len)
|
||||
|
||||
def test_missing_keys_dpo(self):
|
||||
"""Test ValueError raised if keys missing for DPO."""
|
||||
"""
|
||||
Tests that a ValueError is raised when required keys are missing for DPO samples.
|
||||
|
||||
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
|
||||
"""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt, chosen and rejected keys are required"
|
||||
@@ -284,7 +341,12 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_missing_keys_kto(self):
|
||||
"""Test ValueError raised if keys missing for KTO."""
|
||||
"""
|
||||
Tests that a ValueError is raised when required keys are missing for RL type "kto".
|
||||
|
||||
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
|
||||
a ValueError with the expected error message.
|
||||
"""
|
||||
sample = {"prompt": "p"}
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Prompt and completion keys are required"
|
||||
@@ -292,14 +354,18 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
|
||||
|
||||
def test_unknown_rl_type(self):
|
||||
"""Test ValueError raised for unknown RL type."""
|
||||
"""
|
||||
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
|
||||
"""
|
||||
sample = {}
|
||||
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
|
||||
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
|
||||
|
||||
# GRPO test - current implementation always passes
|
||||
def test_grpo_drop(self):
|
||||
"""Test GRPO drop mode (currently always True)."""
|
||||
"""
|
||||
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
|
||||
"""
|
||||
sample = {}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||
@@ -307,7 +373,9 @@ class TestDropLongRLSeq(unittest.TestCase):
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_grpo_truncate(self):
|
||||
"""Test GRPO truncate mode (currently returns original sample)."""
|
||||
"""
|
||||
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
|
||||
"""
|
||||
sample = {"a": 1}
|
||||
result = drop_long_rl_seq(
|
||||
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||
|
||||
@@ -14,11 +14,19 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Example sequence length settings
|
||||
"""
|
||||
Sets up default sequence length parameters for the test cases.
|
||||
"""
|
||||
self.sequence_len = 10
|
||||
self.min_sequence_len = 3
|
||||
|
||||
def test_drop_mode_single(self):
|
||||
"""Test drop mode with single examples."""
|
||||
"""
|
||||
Verifies that 'drop' mode correctly filters single sequence examples based on length.
|
||||
|
||||
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
|
||||
while sequences within the valid length range are kept.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
@@ -43,7 +51,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
self.assertFalse(handler(sample_empty))
|
||||
|
||||
def test_truncate_mode_single(self):
|
||||
"""Test truncate mode with single examples."""
|
||||
"""
|
||||
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
|
||||
|
||||
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
@@ -83,7 +95,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
self.assertEqual(result_empty, sample_empty) # Unchanged
|
||||
|
||||
def test_drop_mode_batched(self):
|
||||
"""Test drop mode with batched examples."""
|
||||
"""
|
||||
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
|
||||
|
||||
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
@@ -103,7 +119,13 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
|
||||
self.assertEqual(handler(sample), expected)
|
||||
|
||||
def test_truncate_mode_batched(self):
|
||||
"""Test truncate mode with batched examples."""
|
||||
"""
|
||||
Tests that batched examples are correctly truncated in "truncate" mode.
|
||||
|
||||
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
|
||||
allowed length are truncated, while sequences that are too short or empty remain
|
||||
unchanged.
|
||||
"""
|
||||
handler = partial(
|
||||
truncate_or_drop_long_seq,
|
||||
sequence_len=self.sequence_len,
|
||||
|
||||
Reference in New Issue
Block a user