Files
axolotl/tests/integrations/test_kd_chat_template.py
Seung Hyun Cho 3e51a680c2 fix: Fix evaluation loss in KD trainer (#3271)
* fix: Fix evaluation loss in KD trainer

* Fix v2 strategy super() call

* fix: Add safety check for total_tokens in log method

* fix: simplified num items and outputs return handling

* fix: add missing model forward pass in compute_loss

* refactor: Use Template Method pattern for chat template strategies

* refactor: use pop(None) and remove v2 override

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-12-17 13:40:36 -05:00

82 lines
2.8 KiB
Python

"""
Test for KD chat template strategies
"""
from unittest.mock import Mock
import pytest
from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
class TestChatTemplateStrategyWithKDv2:
    """Tests for how the v2 KD chat-template strategy treats ``target_token_ids``."""

    @staticmethod
    def _tokenized_sample():
        """Return a fresh tokenized sample dict (new list objects on every call)."""
        return {
            "input_ids": [1, 10, 20, 30, 2],
            "labels": [1, 10, 20, 30, 2],
        }

    @pytest.fixture
    def v2_strategy(self):
        """Build a ``ChatTemplateStrategyWithKDv2`` wired to mock collaborators."""
        # Stand-in prompter carrying only the attributes configured here.
        prompter = Mock(
            roles={"user": "user", "assistant": "assistant"},
            chat_template_msg_variables=["role", "content"],
            chat_template="{{ messages }}",
        )

        # Stand-in tokenizer: fixed special-token ids plus canned call results.
        tokenizer = Mock(
            pad_token_id=0,
            eos_token_id=2,
            bos_token_id=1,
            eos_token="<|endoftext|>",
        )
        tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])
        tokenizer.encode = Mock(return_value=[2])

        strategy_kwargs = {
            "prompter": prompter,
            "tokenizer": tokenizer,
            "train_on_inputs": False,
            "sequence_len": 512,
            "logprobs_field": "logprobs",
            "gen_temperature": 1.0,
            "kd_temperature": 1.0,
        }
        return ChatTemplateStrategyWithKDv2(**strategy_kwargs)

    def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
        """
        ``_prepare_kd_fields`` must expose ``target_token_ids`` in its result.

        The value is expected to equal the one supplied in the original sample.
        """
        source_sample = {"target_token_ids": [[10, 20], [30, 40]]}

        prepared = v2_strategy._prepare_kd_fields(
            self._tokenized_sample(), source_sample
        )

        assert "target_token_ids" in prepared
        assert prepared["target_token_ids"] == [[10, 20], [30, 40]]

    def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
        """``_prepare_kd_fields`` must not invent ``target_token_ids`` when absent."""
        source_sample = {}

        prepared = v2_strategy._prepare_kd_fields(
            self._tokenized_sample(), source_sample
        )

        assert "target_token_ids" not in prepared

    def test_v2_transform_requires_target_token_ids(self, v2_strategy):
        """
        ``transform_logprobs`` must raise ``KeyError`` without ``target_token_ids``.

        The sample carries logprobs but no ``target_token_ids`` key, and the
        raised error is expected to mention that key.
        """
        incomplete_sample = {
            **self._tokenized_sample(),
            "logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
        }

        with pytest.raises(KeyError, match="target_token_ids"):
            v2_strategy.transform_logprobs(incomplete_sample)