* fix: Fix evaluation loss in KD trainer * Fix v2 strategy super() call * fix: Add safety check for total_tokens in log method * fix: simplified num items and outputs return handling * fix: add missing model forward pass in compute_loss * refactor: Use Template Method pattern for chat template strategies * refactor: use pop(None) and remove v2 override * chore: lint --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
"""
Tests for the KD chat template strategies.
"""
|
|
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
|
|
|
|
|
|
class TestChatTemplateStrategyWithKDv2:
    """Checks that the v2 KD strategy both propagates and depends on target_token_ids."""

    # Token ids used as both `input_ids` and `labels` in every sample below.
    _TOKEN_IDS = (1, 10, 20, 30, 2)

    @classmethod
    def _tokenized_sample(cls):
        """Return a fresh tokenized-sample dict so no test shares mutable lists."""
        return {
            "input_ids": list(cls._TOKEN_IDS),
            "labels": list(cls._TOKEN_IDS),
        }

    @staticmethod
    def _target_token_ids():
        """Return a fresh nested list of target token ids for each call."""
        return [[10, 20], [30, 40]]

    @pytest.fixture
    def v2_strategy(self):
        """Build a ChatTemplateStrategyWithKDv2 wired to Mock collaborators.

        Only the attributes configured here are fixed; any other attribute the
        strategy touches on the prompter or tokenizer is auto-created by Mock.
        """
        prompter = Mock(
            roles={"user": "user", "assistant": "assistant"},
            chat_template_msg_variables=["role", "content"],
            chat_template="{{ messages }}",
        )
        tokenizer = Mock(
            pad_token_id=0,
            bos_token_id=1,
            eos_token_id=2,
            eos_token="<|endoftext|>",
            # Every chat-template render yields the same BOS ... EOS id sequence.
            apply_chat_template=Mock(return_value=[1, 10, 20, 30, 2]),
            # encode() always answers with the EOS id configured above.
            encode=Mock(return_value=[2]),
        )

        return ChatTemplateStrategyWithKDv2(
            prompter=prompter,
            tokenizer=tokenizer,
            logprobs_field="logprobs",
            train_on_inputs=False,
            sequence_len=512,
            kd_temperature=1.0,
            gen_temperature=1.0,
        )

    def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
        """
        Check that v2's `_prepare_kd_fields` hook exposes target_token_ids.

        v2 overrides this Template Method hook so that target_token_ids is
        already present on the sample before the transform step runs.
        """
        source_row = {"target_token_ids": self._target_token_ids()}

        prepared = v2_strategy._prepare_kd_fields(
            self._tokenized_sample(), source_row
        )

        assert "target_token_ids" in prepared
        assert prepared["target_token_ids"] == self._target_token_ids()

    def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
        """A source row without target_token_ids must not raise or gain the key."""
        prepared = v2_strategy._prepare_kd_fields(self._tokenized_sample(), {})

        assert "target_token_ids" not in prepared

    def test_v2_transform_requires_target_token_ids(self, v2_strategy):
        """
        Check that v2's `transform_logprobs` cannot run without target_token_ids.

        Regression guard: the transform relies on the hook having added the
        field, so a sample lacking it must raise a KeyError naming the field.
        """
        sample_without_targets = {
            **self._tokenized_sample(),
            "logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
        }

        with pytest.raises(KeyError, match="target_token_ids"):
            v2_strategy.transform_logprobs(sample_without_targets)
|