increased test coverage
This commit is contained in:
@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from transformers import LlamaTokenizer
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import encode_pretraining, md5
|
from axolotl.utils.data import encode_pretraining, md5
|
||||||
|
from axolotl.utils.data.rl import drop_long_rl_seq
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
@@ -64,5 +66,254 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDropLongRLSeq(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Tests for the drop_long_rl_seq function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# Mock tokenizer that returns length based on input string length
|
||||||
|
self.tokenizer = MagicMock()
|
||||||
|
|
||||||
|
def side_effect_func(
|
||||||
|
text, add_special_tokens=False
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
return {"input_ids": list(range(len(text)))}
|
||||||
|
|
||||||
|
self.tokenizer.side_effect = side_effect_func
|
||||||
|
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
|
||||||
|
["x"] * len(tokens)
|
||||||
|
) # pylint: disable=unused-argument
|
||||||
|
|
||||||
|
self.sequence_len = 20
|
||||||
|
|
||||||
|
def test_dpo_drop_mode_valid(self):
|
||||||
|
"""Test DPO drop mode with a valid sample."""
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * 5,
|
||||||
|
"chosen": "c" * 7,
|
||||||
|
"rejected": "r" * 6,
|
||||||
|
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_dpo_drop_mode_invalid_chosen(self):
|
||||||
|
"""Test DPO drop mode with chosen too long."""
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * 5,
|
||||||
|
"chosen": "c" * 16,
|
||||||
|
"rejected": "r" * 6,
|
||||||
|
} # 5+16=21 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_dpo_drop_mode_invalid_rejected(self):
|
||||||
|
"""Test DPO drop mode with rejected too long."""
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * 5,
|
||||||
|
"chosen": "c" * 7,
|
||||||
|
"rejected": "r" * 16,
|
||||||
|
} # 5+16=21 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_no_truncation_needed(self):
|
||||||
|
"""Test DPO truncate mode when no truncation is needed."""
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * 5,
|
||||||
|
"chosen": "c" * 7,
|
||||||
|
"rejected": "r" * 6,
|
||||||
|
} # 5+7=12 <= 20, 5+6=11 <= 20
|
||||||
|
original_sample = sample.copy()
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
result, original_sample
|
||||||
|
) # Should return the original sample unchanged
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_prompt_too_long(self):
|
||||||
|
"""Test DPO truncate mode when the prompt itself is too long."""
|
||||||
|
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
|
||||||
|
original_sample = sample.copy()
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
# Even though truncation isn't possible, the function should return the original sample
|
||||||
|
# for the map operation, assuming downstream filtering will catch it.
|
||||||
|
self.assertEqual(result, original_sample)
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_chosen_truncated(self):
|
||||||
|
"""Test DPO truncate mode when only 'chosen' needs truncation."""
|
||||||
|
prompt_len = 5
|
||||||
|
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * prompt_len,
|
||||||
|
"chosen": "c" * 18,
|
||||||
|
"rejected": "r" * 10,
|
||||||
|
} # 5+18=23 > 20, 5+10=15 <= 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||||
|
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
|
||||||
|
self.assertEqual(
|
||||||
|
result["chosen"], "x" * max_resp_len
|
||||||
|
) # Check decoded truncated value
|
||||||
|
self.assertEqual(len(result["rejected"]), 10) # Unchanged
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_rejected_truncated(self):
|
||||||
|
"""Test DPO truncate mode when only 'rejected' needs truncation."""
|
||||||
|
prompt_len = 5
|
||||||
|
max_resp_len = self.sequence_len - prompt_len # 15
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * prompt_len,
|
||||||
|
"chosen": "c" * 10,
|
||||||
|
"rejected": "r" * 18,
|
||||||
|
} # 5+10=15 <= 20, 5+18=23 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||||
|
self.assertEqual(len(result["chosen"]), 10) # Unchanged
|
||||||
|
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
|
||||||
|
self.assertEqual(
|
||||||
|
result["rejected"], "x" * max_resp_len
|
||||||
|
) # Check decoded truncated value
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_both_truncated(self):
|
||||||
|
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
|
||||||
|
prompt_len = 8
|
||||||
|
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * prompt_len,
|
||||||
|
"chosen": "c" * 15,
|
||||||
|
"rejected": "r" * 14,
|
||||||
|
} # 8+15=23 > 20, 8+14=22 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||||
|
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
|
||||||
|
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||||
|
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
|
||||||
|
self.assertEqual(result["rejected"], "x" * max_resp_len)
|
||||||
|
|
||||||
|
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
|
||||||
|
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
|
||||||
|
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
|
||||||
|
# but the initial check failed because e.g. prompt + chosen > sequence_len
|
||||||
|
# The current logic *will* truncate if len(chosen) > max_resp_len.
|
||||||
|
# Let's test a case where one is slightly too long causing the initial fail,
|
||||||
|
# but the other fits *within* the max_response_len, so only one gets truncated.
|
||||||
|
prompt_len = 10
|
||||||
|
max_resp_len = self.sequence_len - prompt_len # 10
|
||||||
|
sample = {
|
||||||
|
"prompt": "p" * prompt_len,
|
||||||
|
"chosen": "c" * 11,
|
||||||
|
"rejected": "r" * 9,
|
||||||
|
} # 10+11=21 > 20, 10+9=19 <= 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||||
|
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
|
||||||
|
self.assertEqual(result["chosen"], "x" * max_resp_len)
|
||||||
|
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
|
||||||
|
|
||||||
|
# Add similar tests for KTO if needed, checking prompt + completion length
|
||||||
|
|
||||||
|
def test_kto_drop_mode_valid(self):
|
||||||
|
"""Test KTO drop mode with a valid sample."""
|
||||||
|
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_kto_drop_mode_invalid(self):
|
||||||
|
"""Test KTO drop mode with an invalid sample."""
|
||||||
|
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_kto_truncate_mode_no_truncation_needed(self):
|
||||||
|
"""Test KTO truncate mode when no truncation is needed."""
|
||||||
|
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
|
||||||
|
original_sample = sample.copy()
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(result, original_sample)
|
||||||
|
|
||||||
|
def test_kto_truncate_mode_prompt_too_long(self):
|
||||||
|
"""Test KTO truncate mode when the prompt itself is too long."""
|
||||||
|
sample = {"prompt": "p" * 25, "completion": "c" * 7}
|
||||||
|
original_sample = sample.copy()
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(result, original_sample) # Returns original sample
|
||||||
|
|
||||||
|
def test_kto_truncate_mode_completion_truncated(self):
|
||||||
|
"""Test KTO truncate mode when completion needs truncation."""
|
||||||
|
prompt_len = 8
|
||||||
|
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
|
||||||
|
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(len(result["prompt"]), prompt_len)
|
||||||
|
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
|
||||||
|
self.assertEqual(result["completion"], "x" * max_comp_len)
|
||||||
|
|
||||||
|
def test_missing_keys_dpo(self):
|
||||||
|
"""Test ValueError raised if keys missing for DPO."""
|
||||||
|
sample = {"prompt": "p"}
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "Prompt, chosen and rejected keys are required"
|
||||||
|
):
|
||||||
|
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
|
||||||
|
|
||||||
|
def test_missing_keys_kto(self):
|
||||||
|
"""Test ValueError raised if keys missing for KTO."""
|
||||||
|
sample = {"prompt": "p"}
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "Prompt and completion keys are required"
|
||||||
|
):
|
||||||
|
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
|
||||||
|
|
||||||
|
def test_unknown_rl_type(self):
|
||||||
|
"""Test ValueError raised for unknown RL type."""
|
||||||
|
sample = {}
|
||||||
|
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
|
||||||
|
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
|
||||||
|
|
||||||
|
# GRPO test - current implementation always passes
|
||||||
|
def test_grpo_drop(self):
|
||||||
|
"""Test GRPO drop mode (currently always True)."""
|
||||||
|
sample = {}
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_grpo_truncate(self):
|
||||||
|
"""Test GRPO truncate mode (currently returns original sample)."""
|
||||||
|
sample = {"a": 1}
|
||||||
|
result = drop_long_rl_seq(
|
||||||
|
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user