try run regular CE loss on eval

This commit is contained in:
Salman Mohammadi
2026-01-19 22:00:48 +00:00
parent 170dca9bb9
commit 7a4f33802d
3 changed files with 36 additions and 10 deletions

View File

@@ -373,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import dft_loss
trainer_kwargs["compute_loss_func"] = dft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(

View File

@@ -159,10 +159,9 @@ class PatchManager:
patch_chunked_ce_loss_fn()
def _apply_dft_loss_patch(self):
if self.cfg.use_dynamic_finetuning:
from axolotl.monkeypatch.loss.dft import patch_dft_loss_fn
patch_dft_loss_fn()
# DFT loss is now applied via compute_loss_func in the trainer builder
# See: src/axolotl/core/builders/causal.py
pass
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""

View File

@@ -67,10 +67,32 @@ def get_dft_loss(ignore_index: int = -100):
return for_causal_lm_dft_loss
def patch_dft_loss_fn(ignore_index: int = -100):
"""Patch transformers to use DFT loss for causal LM"""
import transformers.loss.loss_utils
def dft_loss(outputs, labels, num_items_in_batch=None):
"""DFT loss compatible with Trainer.compute_loss_func signature.
for_causal_lm_dft_loss = get_dft_loss(ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_dft_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = for_causal_lm_dft_loss
This function is designed to be passed to Trainer's compute_loss_func parameter.
"""
ignore_index = -100
# Shift labels for causal LM
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
shift_labels = shift_labels.to(outputs.logits.device)
# Create loss mask
loss_mask = shift_labels != ignore_index
shift_labels_masked = shift_labels.clone()
shift_labels_masked[~loss_mask] = 0
# Compute log probabilities
logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
# DFT loss: -exp(logprobs).detach() * logprobs
per_token_loss = -logprobs.exp().detach() * logprobs
# Sum over valid tokens and normalize
if num_items_in_batch is None:
num_items_in_batch = loss_mask.sum()
loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
return loss