📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length

Docstrings generation was requested by @mhenrichsen.

* https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776

The following files were modified:

* `src/axolotl/utils/data/pretraining.py`
* `src/axolotl/utils/data/rl.py`
* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
* `tests/test_data.py`
* `tests/test_trainer_utils.py`
This commit is contained in:
coderabbitai[bot]
2025-05-15 11:02:45 +00:00
committed by GitHub
parent 5d7a61576d
commit e23a5c9fda
6 changed files with 215 additions and 38 deletions

View File

@@ -60,6 +60,9 @@ class TestEncodePretraining(unittest.TestCase):
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
"""
Tests that the md5 function returns the correct hash for a given string and encoding.
"""
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
@@ -73,11 +76,26 @@ class TestDropLongRLSeq(unittest.TestCase):
def setUp(self):
# Mock tokenizer that returns length based on input string length
"""
Sets up a mock tokenizer and sequence length for RL sequence length tests.
The mock tokenizer simulates tokenization by returning input IDs equal to the input string's length and decodes tokens as repeated "x" characters. The sequence length limit is set to 20.
"""
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
"""
Simulates tokenization by returning input IDs as a sequence of integers equal to the input text length.
Args:
text: The input string to tokenize.
add_special_tokens: Ignored parameter included for interface compatibility.
Returns:
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
"""
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
@@ -88,7 +106,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""Test DPO drop mode with a valid sample."""
"""
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -100,7 +120,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""Test DPO drop mode with chosen too long."""
"""
Tests that in DPO drop mode, a sample is rejected when the prompt and chosen lengths exceed the sequence limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
@@ -112,7 +134,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""Test DPO drop mode with rejected too long."""
"""
Tests that in DPO drop mode, a sample is rejected when the prompt plus rejected response exceeds the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -124,7 +148,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""Test DPO truncate mode when no truncation is needed."""
"""
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -139,7 +165,10 @@ class TestDropLongRLSeq(unittest.TestCase):
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""Test DPO truncate mode when the prompt itself is too long."""
"""
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -150,7 +179,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""Test DPO truncate mode when only 'chosen' needs truncation."""
"""
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
@@ -169,7 +200,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""Test DPO truncate mode when only 'rejected' needs truncation."""
"""
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
@@ -188,7 +221,11 @@ class TestDropLongRLSeq(unittest.TestCase):
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
"""
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
"""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
@@ -206,7 +243,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
"""
Tests DPO truncate mode where only the overlong response is truncated.
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
"""
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
# but the initial check failed because e.g. prompt + chosen > sequence_len
# The current logic *will* truncate if len(chosen) > max_resp_len.
@@ -230,7 +271,9 @@ class TestDropLongRLSeq(unittest.TestCase):
# Add similar tests for KTO if needed, checking prompt + completion length
def test_kto_drop_mode_valid(self):
"""Test KTO drop mode with a valid sample."""
"""
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
@@ -238,7 +281,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""Test KTO drop mode with an invalid sample."""
"""
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
@@ -246,7 +291,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""Test KTO truncate mode when no truncation is needed."""
"""
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -255,7 +302,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""Test KTO truncate mode when the prompt itself is too long."""
"""
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -264,7 +313,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""Test KTO truncate mode when completion needs truncation."""
"""
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
"""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
@@ -276,7 +329,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""Test ValueError raised if keys missing for DPO."""
"""
Tests that a ValueError is raised when required keys are missing for DPO samples.
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
@@ -284,7 +341,12 @@ class TestDropLongRLSeq(unittest.TestCase):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""Test ValueError raised if keys missing for KTO."""
"""
Tests that a ValueError is raised when required keys are missing for RL type "kto".
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
a ValueError with the expected error message.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
@@ -292,14 +354,18 @@ class TestDropLongRLSeq(unittest.TestCase):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""Test ValueError raised for unknown RL type."""
"""
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
"""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""Test GRPO drop mode (currently always True)."""
"""
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
"""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
@@ -307,7 +373,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_grpo_truncate(self):
"""Test GRPO truncate mode (currently returns original sample)."""
"""
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
"""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"

View File

@@ -14,11 +14,19 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
def setUp(self):
# Example sequence length settings
"""
Sets up default sequence length parameters for the test cases.
"""
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
"""
Verifies that 'drop' mode correctly filters single sequence examples based on length.
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
while sequences within the valid length range are kept.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -43,7 +51,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
"""
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -83,7 +95,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
"""
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -103,7 +119,13 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
"""
Tests that batched examples are correctly truncated in "truncate" mode.
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
allowed length are truncated, while sequences that are too short or empty remain
unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,