fix: Fix evaluation loss in KD trainer (#3271)
* fix: Fix evaluation loss in KD trainer * Fix v2 strategy super() call * fix: Add safety check for total_tokens in log method * fix: simplified num items and outputs return handling * fix: add missing model forward pass in compute_loss * refactor: Use Template Method pattern for chat template strategies * refactor: use pop(None) and remove v2 override * chore: lint --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -631,7 +631,11 @@ class AxolotlTrainer(
|
|||||||
logs["tokens_per_second_per_gpu"] = round(
|
logs["tokens_per_second_per_gpu"] = round(
|
||||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||||
)
|
)
|
||||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
if (
|
||||||
|
hasattr(self.state, "total_tokens")
|
||||||
|
and self.state.total_tokens is not None
|
||||||
|
):
|
||||||
|
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
|
|||||||
@@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
logprobs = prompt.pop(self.logprobs_field)
|
logprobs = prompt.pop(self.logprobs_field)
|
||||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||||
tokenized_prompt[self.logprobs_field] = logprobs
|
tokenized_prompt[self.logprobs_field] = logprobs
|
||||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
|
||||||
|
|
||||||
|
# let subclasses add fields before transform
|
||||||
|
tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
|
||||||
|
|
||||||
|
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||||
|
return tokenized_prompt
|
||||||
|
|
||||||
|
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||||
|
"""
|
||||||
|
Hook for subclasses to prepare additional KD fields before transform
|
||||||
|
"""
|
||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
@@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
|||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def _tokenize_single_prompt(self, prompt):
|
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||||
target_token_ids = prompt.get("target_token_ids", None)
|
"""
|
||||||
|
Add pre-tokenized target_token_ids for v2 format
|
||||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
"""
|
||||||
|
target_token_ids = original_prompt.pop("target_token_ids", None)
|
||||||
if target_token_ids is not None:
|
if target_token_ids is not None:
|
||||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||||
|
|
||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
KD trainer
|
KD trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||||
@@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
if columns_to_add:
|
if columns_to_add:
|
||||||
self._signature_columns += columns_to_add
|
self._signature_columns += columns_to_add
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
):
|
):
|
||||||
del inputs["attention_mask"]
|
del inputs["attention_mask"]
|
||||||
|
|
||||||
|
if num_items_in_batch is None and "labels" in inputs:
|
||||||
|
num_items_in_batch = (inputs["labels"] != -100).sum().item()
|
||||||
|
|
||||||
if self.model_accepts_loss_kwargs:
|
if self.model_accepts_loss_kwargs:
|
||||||
loss_kwargs = {}
|
loss_kwargs = {}
|
||||||
if num_items_in_batch is not None:
|
if num_items_in_batch is not None:
|
||||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||||
inputs = {**inputs, **loss_kwargs}
|
inputs = {**inputs, **loss_kwargs}
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
return outputs[0]
|
|
||||||
|
if isinstance(outputs, dict):
|
||||||
|
loss = outputs["loss"]
|
||||||
|
elif isinstance(outputs, tuple):
|
||||||
|
loss = outputs[0]
|
||||||
|
else:
|
||||||
|
loss = outputs.loss if hasattr(outputs, "loss") else outputs
|
||||||
|
|
||||||
|
return (loss, outputs) if return_outputs else loss
|
||||||
|
|||||||
81
tests/integrations/test_kd_chat_template.py
Normal file
81
tests/integrations/test_kd_chat_template.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""
|
||||||
|
Test for KD chat template strategies
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatTemplateStrategyWithKDv2:
|
||||||
|
"""Test v2 strategy correctly handles target_token_ids"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def v2_strategy(self):
|
||||||
|
"""Create v2 strategy instance with mocked dependencies"""
|
||||||
|
# Mock prompter
|
||||||
|
mock_prompter = Mock()
|
||||||
|
mock_prompter.roles = {"user": "user", "assistant": "assistant"}
|
||||||
|
mock_prompter.chat_template_msg_variables = ["role", "content"]
|
||||||
|
mock_prompter.chat_template = "{{ messages }}"
|
||||||
|
|
||||||
|
# Mock tokenizer
|
||||||
|
mock_tokenizer = Mock()
|
||||||
|
mock_tokenizer.pad_token_id = 0
|
||||||
|
mock_tokenizer.eos_token_id = 2
|
||||||
|
mock_tokenizer.bos_token_id = 1
|
||||||
|
mock_tokenizer.eos_token = "<|endoftext|>"
|
||||||
|
mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2])
|
||||||
|
mock_tokenizer.encode = Mock(return_value=[2])
|
||||||
|
|
||||||
|
return ChatTemplateStrategyWithKDv2(
|
||||||
|
prompter=mock_prompter,
|
||||||
|
tokenizer=mock_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
logprobs_field="logprobs",
|
||||||
|
gen_temperature=1.0,
|
||||||
|
kd_temperature=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy):
|
||||||
|
"""
|
||||||
|
Test that v2's _prepare_kd_fields hook adds target_token_ids.
|
||||||
|
|
||||||
|
Validates the Template Method pattern fix where v2 overrides
|
||||||
|
the hook to add target_token_ids before transform.
|
||||||
|
"""
|
||||||
|
tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
|
||||||
|
original = {"target_token_ids": [[10, 20], [30, 40]]}
|
||||||
|
|
||||||
|
result = v2_strategy._prepare_kd_fields(tokenized, original)
|
||||||
|
|
||||||
|
assert "target_token_ids" in result
|
||||||
|
assert result["target_token_ids"] == [[10, 20], [30, 40]]
|
||||||
|
|
||||||
|
def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy):
|
||||||
|
"""Test hook handles missing target_token_ids gracefully"""
|
||||||
|
tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]}
|
||||||
|
original = {}
|
||||||
|
|
||||||
|
result = v2_strategy._prepare_kd_fields(tokenized, original)
|
||||||
|
|
||||||
|
assert "target_token_ids" not in result
|
||||||
|
|
||||||
|
def test_v2_transform_requires_target_token_ids(self, v2_strategy):
|
||||||
|
"""
|
||||||
|
Test v2's transform fails without target_token_ids.
|
||||||
|
|
||||||
|
Validates the bug fix - transform expects target_token_ids
|
||||||
|
to be added by the hook.
|
||||||
|
"""
|
||||||
|
sample = {
|
||||||
|
"input_ids": [1, 10, 20, 30, 2],
|
||||||
|
"labels": [1, 10, 20, 30, 2],
|
||||||
|
"logprobs": [[-0.1, -0.2], [-0.3, -0.4]],
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(KeyError, match="target_token_ids"):
|
||||||
|
v2_strategy.transform_logprobs(sample)
|
||||||
Reference in New Issue
Block a user