From 7a4f33802dae943c718df1fea63e6661ad73d1c3 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi <salman.mohammadi@outlook.com>
Date: Mon, 19 Jan 2026 22:00:48 +0000
Subject: [PATCH] try running regular CE loss on eval

---
 src/axolotl/core/builders/causal.py  |  5 ++++
 src/axolotl/loaders/patch_manager.py |  7 +++---
 src/axolotl/monkeypatch/loss/dft.py  | 34 +++++++++++++++++++++++-----
 3 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index cda98087f..b3545f1b2 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -373,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = multiple
 
+        if self.cfg.use_dynamic_finetuning:
+            from axolotl.monkeypatch.loss.dft import dft_loss
+
+            trainer_kwargs["compute_loss_func"] = dft_loss
+
         trainer_cls = self._get_trainer_cls()
 
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index d917fd5a6..d89d32f11 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -159,10 +159,9 @@ class PatchManager:
             patch_chunked_ce_loss_fn()
 
     def _apply_dft_loss_patch(self):
-        if self.cfg.use_dynamic_finetuning:
-            from axolotl.monkeypatch.loss.dft import patch_dft_loss_fn
-
-            patch_dft_loss_fn()
+        # DFT loss is now applied via compute_loss_func in the trainer builder.
+        # See: src/axolotl/core/builders/causal.py
+        pass
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
diff --git a/src/axolotl/monkeypatch/loss/dft.py b/src/axolotl/monkeypatch/loss/dft.py
index b16144427..652ca6abb 100644
--- a/src/axolotl/monkeypatch/loss/dft.py
+++ b/src/axolotl/monkeypatch/loss/dft.py
@@ -67,10 +67,32 @@ def get_dft_loss(ignore_index: int = -100):
     return for_causal_lm_dft_loss
 
 
-def patch_dft_loss_fn(ignore_index: int = -100):
-    """Patch transformers to use DFT loss for causal LM"""
-    import transformers.loss.loss_utils
-
-    for_causal_lm_dft_loss = get_dft_loss(ignore_index)
-    transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_dft_loss
-    transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = for_causal_lm_dft_loss
+def dft_loss(outputs, labels, num_items_in_batch=None):
+    """DFT loss compatible with the Trainer compute_loss_func signature.
+
+    Per-token loss is cross-entropy reweighted by the detached token probability.
+    """
+    ignore_index = -100
+
+    # Shift labels for causal LM: logits at position i predict the token at i + 1
+    labels = F.pad(labels, (0, 1), value=ignore_index)
+    shift_labels = labels[..., 1:].contiguous()
+    shift_labels = shift_labels.to(outputs.logits.device)
+
+    # Mask positions that should not contribute to the loss
+    loss_mask = shift_labels != ignore_index
+    shift_labels_masked = shift_labels.clone()
+    shift_labels_masked[~loss_mask] = 0
+
+    # Log-probabilities of the target tokens
+    logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
+
+    # DFT loss: -exp(logprobs).detach() * logprobs
+    per_token_loss = -logprobs.exp().detach() * logprobs
+
+    # Sum over valid tokens and normalize
+    if num_items_in_batch is None:
+        num_items_in_batch = loss_mask.sum()
+
+    loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
+    return loss
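
Note: the DFT objective above is plain token-level cross-entropy reweighted by
the detached probability of the target token, i.e. -p.detach() * log p. The
following is a minimal, self-contained sketch of the same computation, assuming
selective_log_softmax behaves like TRL's helper of that name (a gather over
log_softmax); the toy shapes and the cross-entropy reference check are
illustrative assumptions, not part of the patch.

    import torch
    import torch.nn.functional as F

    def selective_log_softmax(logits, index):
        # Stand-in for trl's helper: log-probability assigned to each index token.
        return F.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)

    # Toy batch: 2 sequences, 5 positions, vocabulary of 11 tokens.
    torch.manual_seed(0)
    logits = torch.randn(2, 5, 11)
    labels = torch.randint(0, 11, (2, 5))
    labels[0, :2] = -100  # mask a prompt prefix, as ignore_index would

    # Same steps as dft_loss: shift, mask, per-token DFT term.
    shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous()
    loss_mask = shift_labels != -100
    safe_labels = shift_labels.clone()
    safe_labels[~loss_mask] = 0

    logprobs = selective_log_softmax(logits, safe_labels)
    per_token = -logprobs.exp().detach() * logprobs
    dft = (per_token * loss_mask).sum() / loss_mask.sum()

    # Reference: DFT equals per-token CE weighted by the detached exp(-CE).
    ce = F.cross_entropy(
        logits.reshape(-1, 11),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).reshape(2, 5)
    reference = (torch.exp(-ce).detach() * ce * loss_mask).sum() / loss_mask.sum()
    assert torch.allclose(dft, reference)

Because the weighting factor is detached, gradients flow only through the
log p term. Relative to the removed patch_dft_loss_fn route, which replaced
ForCausalLMLoss inside transformers globally, wiring dft_loss through
trainer_kwargs["compute_loss_func"] leaves the model's built-in cross-entropy
loss untouched.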