fix: Fix evaluation loss in KD trainer (#3271)

* fix: Fix evaluation loss in KD trainer

* Fix v2 strategy super() call

* fix: Add safety check for total_tokens in log method

* fix: simplified num items and outputs return handling

* fix: add missing model forward pass in compute_loss

* refactor: Use Template Method pattern for chat template strategies

* refactor: use pop(None) and remove v2 override

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Authored by Seung Hyun Cho on 2025-12-18 03:40:36 +09:00, committed by GitHub
parent 2cf254b4af
commit 3e51a680c2
4 changed files with 117 additions and 9 deletions
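At its core, the commit replaces a v2-specific `_tokenize_single_prompt` override with a Template Method hook: the base KD strategy tokenizes the prompt, then calls `_prepare_kd_fields` so subclasses can attach extra fields before `transform_logprobs` runs. A minimal sketch of that pattern follows; the class names here are illustrative stand-ins, not the real axolotl classes:

    class BaseKDStrategy:
        def tokenize(self, prompt: dict) -> dict:
            tokenized = {"input_ids": [1, 2, 3]}  # stand-in for real tokenization
            # template method: subclasses may add fields before the transform step
            tokenized = self._prepare_kd_fields(tokenized, prompt)
            return self._transform(tokenized)

        def _prepare_kd_fields(self, tokenized: dict, original: dict) -> dict:
            return tokenized  # default hook: no extra fields

        def _transform(self, tokenized: dict) -> dict:
            return tokenized  # stand-in for transform_logprobs

    class V2KDStrategy(BaseKDStrategy):
        def _prepare_kd_fields(self, tokenized: dict, original: dict) -> dict:
            # pop(..., None): consume the field if present, tolerate its absence
            target = original.pop("target_token_ids", None)
            if target is not None:
                tokenized["target_token_ids"] = target
            return tokenized

    v2 = V2KDStrategy()
    out = v2.tokenize({"target_token_ids": [[10, 20]]})
    assert out["target_token_ids"] == [[10, 20]]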

View File

@@ -631,7 +631,11 @@ class AxolotlTrainer(
             logs["tokens_per_second_per_gpu"] = round(
                 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
             )
-            logs["total_tokens"] = int(self.state.total_tokens.item())
+            if (
+                hasattr(self.state, "total_tokens")
+                and self.state.total_tokens is not None
+            ):
+                logs["total_tokens"] = int(self.state.total_tokens.item())
 
         del self._stored_metrics[train_eval]
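The guard appears aimed at eval-time logging, where `state.total_tokens` may not have been populated yet; the old unconditional `int(self.state.total_tokens.item())` would then raise. A small sketch of the guarded behavior, using illustrative stand-ins (`log_total_tokens`, `_Scalar`) for the trainer state and a zero-dim torch tensor:

    from types import SimpleNamespace

    class _Scalar:
        """Stand-in for a 0-dim torch tensor."""
        def __init__(self, value):
            self.value = value
        def item(self):
            return self.value

    def log_total_tokens(state) -> dict:
        logs = {}
        # covers both failure modes: attribute never set, or present but None
        if hasattr(state, "total_tokens") and state.total_tokens is not None:
            logs["total_tokens"] = int(state.total_tokens.item())
        return logs

    assert log_total_tokens(SimpleNamespace()) == {}                   # never set
    assert log_total_tokens(SimpleNamespace(total_tokens=None)) == {}  # set to None
    assert log_total_tokens(SimpleNamespace(total_tokens=_Scalar(123))) == {"total_tokens": 123}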

View File

@@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
         logprobs = prompt.pop(self.logprobs_field)
         tokenized_prompt = super()._tokenize_single_prompt(prompt)
         tokenized_prompt[self.logprobs_field] = logprobs
-        tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+
+        # let subclasses add fields before transform
+        tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
+        tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+
         return tokenized_prompt
+
+    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
+        """
+        Hook for subclasses to prepare additional KD fields before transform
+        """
+        return tokenized_prompt
@@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
         return sample
 
-    def _tokenize_single_prompt(self, prompt):
-        target_token_ids = prompt.get("target_token_ids", None)
-        tokenized_prompt = super()._tokenize_single_prompt(prompt)
+    def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
+        """
+        Add pre-tokenized target_token_ids for v2 format
+        """
+        target_token_ids = original_prompt.pop("target_token_ids", None)
         if target_token_ids is not None:
             tokenized_prompt["target_token_ids"] = target_token_ids
         return tokenized_prompt
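A note on `pop(..., None)` versus the old `get(...)`: popping consumes `target_token_ids` from the source prompt, so the raw field does not linger in the sample after it has been copied into the tokenized output, and a missing field falls through to `None` instead of raising. A toy illustration:

    prompt = {"messages": [], "target_token_ids": [[10, 20]]}

    ids = prompt.pop("target_token_ids", None)           # -> [[10, 20]], key removed
    assert ids == [[10, 20]]
    assert "target_token_ids" not in prompt
    assert prompt.pop("target_token_ids", None) is None  # absent key: no KeyError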

View File

@@ -16,6 +16,8 @@
 KD trainer
 """
 
+from typing_extensions import override
+
 from axolotl.core.trainers.base import AxolotlTrainer
 
 from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
@@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
         if columns_to_add:
             self._signature_columns += columns_to_add
 
+    @override
     def compute_loss(
         self,
         model,
@@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
         ):
             del inputs["attention_mask"]
 
+        if num_items_in_batch is None and "labels" in inputs:
+            num_items_in_batch = (inputs["labels"] != -100).sum().item()
+
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
             if num_items_in_batch is not None:
                 loss_kwargs["num_items_in_batch"] = num_items_in_batch
             inputs = {**inputs, **loss_kwargs}
-        return outputs[0]
+
+        outputs = model(**inputs)
+
+        if isinstance(outputs, dict):
+            loss = outputs["loss"]
+        elif isinstance(outputs, tuple):
+            loss = outputs[0]
+        else:
+            loss = outputs.loss if hasattr(outputs, "loss") else outputs
+        return (loss, outputs) if return_outputs else loss
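The rewritten `compute_loss` first runs the forward pass that the old code omitted (per the "add missing model forward pass" bullet, the removed `return outputs[0]` used a name no forward pass had bound), then normalizes the three return shapes a model may produce: a plain dict, a tuple with loss first, or a `ModelOutput`-style object with a `.loss` attribute. A standalone sketch of that normalization, with `extract_loss` as an illustrative helper and no transformers dependency:

    def extract_loss(outputs):
        """Pull the loss out of dict / tuple / output-object returns."""
        if isinstance(outputs, dict):
            return outputs["loss"]
        if isinstance(outputs, tuple):
            return outputs[0]
        # e.g. a transformers ModelOutput exposes .loss; failing that,
        # assume the model returned the loss itself
        return outputs.loss if hasattr(outputs, "loss") else outputs

    assert extract_loss({"loss": 0.5}) == 0.5
    assert extract_loss((0.5, "logits")) == 0.5
    assert extract_loss(0.5) == 0.5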

View File

@@ -0,0 +1,81 @@
+"""
+Test for KD chat template strategies
+"""
+
+from unittest.mock import Mock
+
+import pytest
+
+from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
+
+
+class TestChatTemplateStrategyWithKDv2:
+    """Test v2 strategy correctly handles target_token_ids"""
+
+    @pytest.fixture
+    def v2_strategy(self):
+        """Create v2 strategy instance with mocked dependencies"""
+        # Mock prompter
+        mock_prompter = Mock()
+        mock_prompter.roles = {"user": "user", "assistant": "assistant"}
+        mock_prompter.chat_template_msg_variables = ["role", "content"]
+        mock_prompter.chat_template = "{{ messages }}"
+
+        # Mock tokenizer
+        mock_tokenizer = Mock()
+        mock_tokenizer.pad_token_id = 0
+        mock_tokenizer.eos_token_id = 2
+        mock_tokenizer.bos_token_id = 1
+        mock_tokenizer.eos_token = "<|endoftext|>"
+        mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])
+        mock_tokenizer.encode = Mock(return_value=[2])
+
+        return ChatTemplateStrategyWithKDv2(
+            prompter=mock_prompter,
+            tokenizer=mock_tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            logprobs_field="logprobs",
+            gen_temperature=1.0,
+            kd_temperature=1.0,
+        )
+
+    def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
+        """
+        Test that v2's _prepare_kd_fields hook adds target_token_ids.
+
+        Validates the Template Method pattern fix where v2 overrides
+        the hook to add target_token_ids before transform.
+        """
+        tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
+        original = {"target_token_ids": [[10, 20], [30, 40]]}
+
+        result = v2_strategy._prepare_kd_fields(tokenized, original)
+
+        assert "target_token_ids" in result
+        assert result["target_token_ids"] == [[10, 20], [30, 40]]
+
+    def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
+        """Test hook handles missing target_token_ids gracefully"""
+        tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
+        original = {}
+
+        result = v2_strategy._prepare_kd_fields(tokenized, original)
+
+        assert "target_token_ids" not in result
+
+    def test_v2_transform_requires_target_token_ids(self, v2_strategy):
+        """
+        Test v2's transform fails without target_token_ids.
+
+        Validates the bug fix - transform expects target_token_ids
+        to be added by the hook.
+        """
+        sample = {
+            "input_ids": [1, 10, 20, 30, 2],
+            "labels": [1, 10, 20, 30, 2],
+            "logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
+        }
+
+        with pytest.raises(KeyError, match="target_token_ids"):
+            v2_strategy.transform_logprobs(sample)
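The new tests exercise the hook in isolation, so they run without loading a real model or tokenizer. Assuming the file lands under the repo's tests tree (the path below is illustrative, not taken from the diff), they can be run with:

    pytest -q tests/test_kd_chat_template.py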