diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 7896c6088..f4414d649 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -631,7 +631,11 @@ class AxolotlTrainer( logs["tokens_per_second_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) - logs["total_tokens"] = int(self.state.total_tokens.item()) + if ( + hasattr(self.state, "total_tokens") + and self.state.total_tokens is not None + ): + logs["total_tokens"] = int(self.state.total_tokens.item()) del self._stored_metrics[train_eval] diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 04f0f24a4..5cae69e7c 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): logprobs = prompt.pop(self.logprobs_field) tokenized_prompt = super()._tokenize_single_prompt(prompt) tokenized_prompt[self.logprobs_field] = logprobs - tokenized_prompt = self.transform_logprobs(tokenized_prompt) + # let subclasses add fields before transform + tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt) + + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + return tokenized_prompt + + def _prepare_kd_fields(self, tokenized_prompt, original_prompt): + """ + Hook for subclasses to prepare additional KD fields before transform + """ return tokenized_prompt @@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): return sample - def _tokenize_single_prompt(self, prompt): - target_token_ids = prompt.get("target_token_ids", None) - - tokenized_prompt = super()._tokenize_single_prompt(prompt) - + def _prepare_kd_fields(self, tokenized_prompt, original_prompt): + """ + Add pre-tokenized target_token_ids for v2 format + """ + target_token_ids = original_prompt.pop("target_token_ids", None) if target_token_ids is not None: tokenized_prompt["target_token_ids"] = target_token_ids - return tokenized_prompt diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index 0e98497a7..343d4c6df 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -16,6 +16,8 @@ KD trainer """ +from typing_extensions import override + from axolotl.core.trainers.base import AxolotlTrainer from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss @@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer): if columns_to_add: self._signature_columns += columns_to_add + @override def compute_loss( self, model, @@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer): ): del inputs["attention_mask"] + if num_items_in_batch is None and "labels" in inputs: + num_items_in_batch = (inputs["labels"] != -100).sum().item() + if self.model_accepts_loss_kwargs: loss_kwargs = {} if num_items_in_batch is not None: loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} + outputs = model(**inputs) - return outputs[0] + + if isinstance(outputs, dict): + loss = outputs["loss"] + elif isinstance(outputs, tuple): + loss = outputs[0] + else: + loss = outputs.loss if hasattr(outputs, "loss") else outputs + + return (loss, outputs) if return_outputs else loss diff --git a/tests/integrations/test_kd_chat_template.py b/tests/integrations/test_kd_chat_template.py new file mode 100644 index 000000000..b828e6c3d --- /dev/null +++ b/tests/integrations/test_kd_chat_template.py @@ -0,0 +1,81 @@ +""" +Test for KD chat template strategies +""" + +from unittest.mock import Mock + +import pytest + +from axolotl.integrations.kd.chat_template import ChatTemplateStrategyWithKDv2 + + +class TestChatTemplateStrategyWithKDv2: + """Test v2 strategy correctly handles target_token_ids""" + + @pytest.fixture + def v2_strategy(self): + """Create v2 strategy instance with mocked dependencies""" + # Mock prompter + mock_prompter = Mock() + mock_prompter.roles = {"user": "user", "assistant": "assistant"} + mock_prompter.chat_template_msg_variables = ["role", "content"] + mock_prompter.chat_template = "{{ messages }}" + + # Mock tokenizer + mock_tokenizer = Mock() + mock_tokenizer.pad_token_id = 0 + mock_tokenizer.eos_token_id = 2 + mock_tokenizer.bos_token_id = 1 + mock_tokenizer.eos_token = "<|endoftext|>" + mock_tokenizer.apply_chat_template = Mock(return_value=[1, 10, 20, 30, 2]) + mock_tokenizer.encode = Mock(return_value=[2]) + + return ChatTemplateStrategyWithKDv2( + prompter=mock_prompter, + tokenizer=mock_tokenizer, + train_on_inputs=False, + sequence_len=512, + logprobs_field="logprobs", + gen_temperature=1.0, + kd_temperature=1.0, + ) + + def test_v2_prepare_kd_fields_adds_target_token_ids(self, v2_strategy): + """ + Test that v2's _prepare_kd_fields hook adds target_token_ids. + + Validates the Template Method pattern fix where v2 overrides + the hook to add target_token_ids before transform. + """ + tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]} + original = {"target_token_ids": [[10, 20], [30, 40]]} + + result = v2_strategy._prepare_kd_fields(tokenized, original) + + assert "target_token_ids" in result + assert result["target_token_ids"] == [[10, 20], [30, 40]] + + def test_v2_prepare_kd_fields_handles_missing_field(self, v2_strategy): + """Test hook handles missing target_token_ids gracefully""" + tokenized = {"input_ids": [1, 10, 20, 30, 2], "labels": [1, 10, 20, 30, 2]} + original = {} + + result = v2_strategy._prepare_kd_fields(tokenized, original) + + assert "target_token_ids" not in result + + def test_v2_transform_requires_target_token_ids(self, v2_strategy): + """ + Test v2's transform fails without target_token_ids. + + Validates the bug fix - transform expects target_token_ids + to be added by the hook. + """ + sample = { + "input_ids": [1, 10, 20, 30, 2], + "labels": [1, 10, 20, 30, 2], + "logprobs": [[-0.1, -0.2], [-0.3, -0.4]], + } + + with pytest.raises(KeyError, match="target_token_ids"): + v2_strategy.transform_logprobs(sample)