try running regular CE loss on eval
@@ -373,6 +373,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
         data_collator_kwargs["pad_to_multiple_of"] = multiple
 
+        if self.cfg.use_dynamic_finetuning:
+            from axolotl.monkeypatch.loss.dft import dft_loss
+
+            trainer_kwargs["compute_loss_func"] = dft_loss
+
         trainer_cls = self._get_trainer_cls()
 
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
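The hunk above wires DFT loss through the Trainer's own loss hook rather than a library-wide patch. Below is a minimal sketch of the contract that hook imposes, assuming a recent transformers release (the compute_loss_func kwarg landed around v4.46); the tiny model, toy dataset, and plain_ce_loss are illustrative stand-ins, not part of this commit:

import torch.nn.functional as F
from datasets import Dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

def plain_ce_loss(outputs, labels, num_items_in_batch=None):
    # The contract dft_loss must satisfy: model outputs in, scalar loss out.
    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
    summed = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
    denom = num_items_in_batch if num_items_in_batch is not None else (shift_labels != -100).sum()
    return summed / denom

rows = {"input_ids": [[1, 2, 3, 4]] * 8, "labels": [[1, 2, 3, 4]] * 8}
trainer = Trainer(
    model=AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2"),
    args=TrainingArguments(output_dir="/tmp/ce-demo", max_steps=1, report_to="none"),
    train_dataset=Dataset.from_dict(rows),
    compute_loss_func=plain_ce_loss,  # the same hook this hunk assigns dft_loss to
)
trainer.train()

Any callable with this (outputs, labels, num_items_in_batch) signature can be swapped in, which is exactly what dft_loss provides.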
@@ -159,10 +159,9 @@ class PatchManager:
         patch_chunked_ce_loss_fn()
 
     def _apply_dft_loss_patch(self):
-        if self.cfg.use_dynamic_finetuning:
-            from axolotl.monkeypatch.loss.dft import patch_dft_loss_fn
-
-            patch_dft_loss_fn()
+        # DFT loss is now applied via compute_loss_func in the trainer builder
+        # See: src/axolotl/core/builders/causal.py
+        pass
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
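For contrast, here is a toy re-creation of the global-patch style this hunk retires (the real removed code appears in the next hunk). Rebinding the module attributes swaps the loss for every causal-LM forward in the process rather than scoping it to one trainer; the registry names below match current transformers releases, but treat them as an assumption:

import transformers.loss.loss_utils as loss_utils

stock_loss = loss_utils.ForCausalLMLoss  # keep a handle so the patch is reversible

def traced_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # Delegate to the stock implementation; a real patch (like the removed
    # patch_dft_loss_fn) would substitute its own objective here.
    return stock_loss(logits, labels, vocab_size, num_items_in_batch, ignore_index, **kwargs)

loss_utils.ForCausalLMLoss = traced_loss
loss_utils.LOSS_MAPPING["ForCausalLM"] = traced_loss

Keeping a handle to the stock function is what makes such a patch reversible; the trainer-level hook avoids the question entirely.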
@@ -67,10 +67,32 @@ def get_dft_loss(ignore_index: int = -100):
     return for_causal_lm_dft_loss
 
 
-def patch_dft_loss_fn(ignore_index: int = -100):
-    """Patch transformers to use DFT loss for causal LM"""
-    import transformers.loss.loss_utils
-
-    for_causal_lm_dft_loss = get_dft_loss(ignore_index)
-    transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_dft_loss
-    transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = for_causal_lm_dft_loss
+def dft_loss(outputs, labels, num_items_in_batch=None):
+    """DFT loss compatible with Trainer.compute_loss_func signature.
+
+    This function is designed to be passed to Trainer's compute_loss_func parameter.
+    """
+    ignore_index = -100
+
+    # Shift labels for causal LM
+    labels = F.pad(labels, (0, 1), value=ignore_index)
+    shift_labels = labels[..., 1:].contiguous()
+    shift_labels = shift_labels.to(outputs.logits.device)
+
+    # Create loss mask
+    loss_mask = shift_labels != ignore_index
+    shift_labels_masked = shift_labels.clone()
+    shift_labels_masked[~loss_mask] = 0
+
+    # Compute log probabilities
+    logprobs = selective_log_softmax(outputs.logits, shift_labels_masked)
+
+    # DFT loss: -exp(logprobs).detach() * logprobs
+    per_token_loss = -logprobs.exp().detach() * logprobs
+
+    # Sum over valid tokens and normalize
+    if num_items_in_batch is None:
+        num_items_in_batch = loss_mask.sum()
+
+    loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
+    return loss
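dft_loss relies on F (torch.nn.functional) and selective_log_softmax being imported elsewhere in dft.py. Below is a self-contained numeric check of the objective, with a toy selective_log_softmax assumed to behave like trl's helper of the same name: per token, the DFT gradient is the cross-entropy gradient scaled by the detached token probability, so confident tokens train at full strength while low-probability ones are damped.

import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # Toy stand-in: log-softmax over the vocab, then pick each target's log-prob.
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

torch.manual_seed(0)
logits = torch.randn(2, 5, 11, requires_grad=True)  # (batch, seq, vocab)
labels = torch.randint(0, 11, (2, 5))

logprobs = selective_log_softmax(logits, labels)
dft = (-logprobs.exp().detach() * logprobs).sum()  # DFT: p.detach() * CE per token
ce = (-logprobs).sum()                             # plain token-level CE

g_dft = torch.autograd.grad(dft, logits, retain_graph=True)[0]
g_ce = torch.autograd.grad(ce, logits)[0]

# Each position's logits only influence that position's log-prob, so the DFT
# gradient is the CE gradient damped by the (detached) token probability.
p = logprobs.exp().detach().unsqueeze(-1)
assert torch.allclose(g_dft, p * g_ce, atol=1e-6)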