EAFT (#3366) [skip ci]

* wip eaft

* fix eaft loss fn

* adding ref

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
salman authored 2026-01-28 11:44:15 +00:00, committed by GitHub
parent fc4e37920b · commit dd9ebaeba1
4 changed files with 158 additions and 0 deletions


@@ -373,6 +373,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
                data_collator_kwargs["pad_to_multiple_of"] = multiple

        if self.cfg.use_eaft:
            from functools import partial

            from axolotl.monkeypatch.loss.eaft import eaft_loss

            configured_eaft_loss = partial(
                eaft_loss,
                alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
                k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
            )
            trainer_kwargs["compute_loss_func"] = configured_eaft_loss

        trainer_cls = self._get_trainer_cls()
        trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(

@@ -0,0 +1,51 @@
"""
EAFT (Entropy-Aware Focal Training) loss implementation.

Weights each token's cross-entropy by an entropy approximation computed from
the top-k logits, so the loss focuses on uncertain (high-entropy) tokens.

Reference: https://github.com/ymxyll/LlamaFactory-EAFT/blob/e2ce19e8efcc226450ee8f2b81dfe4e69f1f945d/src/llamafactory/train/trainer_utils.py
"""

import torch
import torch.nn.functional as F


def eaft_loss(outputs, labels, num_items_in_batch=None, alpha=1.0, k=20):
    """
    Compute EAFT loss with entropy weighting.

    Args:
        outputs: model outputs containing logits
        labels: target labels for computing loss
        num_items_in_batch: token count used as the loss denominator for
            sample packing support; when None, the mean over unmasked tokens
            is used instead
        alpha: exponent applied to the entropy weight (default 1.0)
        k: number of top logits used for the entropy approximation (default 20)
    """
    logits = outputs.logits

    # Shift so that tokens < n predict token n (standard causal LM alignment).
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    vocab_size = shift_logits.size(-1)
    shift_logits_view = shift_logits.view(-1, vocab_size)
    shift_labels_view = shift_labels.view(-1)

    # Ignore positions masked out by the collator (-100).
    mask = shift_labels_view != -100

    # Entropy weights are treated as constants: no gradient flows through them.
    with torch.no_grad():
        top_k_logits, _ = torch.topk(
            shift_logits_view[mask].float(), k=min(k, vocab_size), dim=-1
        )
        top_k_probs = F.softmax(top_k_logits, dim=-1)
        # Entropy of the renormalized top-k distribution; epsilon avoids log(0).
        entropy = -(top_k_probs * torch.log(top_k_probs + 1e-10)).sum(dim=-1)
        weights = torch.pow(entropy, alpha)

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    per_token_loss = loss_fct(shift_logits_view[mask], shift_labels_view[mask])
    weighted_loss = per_token_loss * weights

    if num_items_in_batch is not None:
        loss = weighted_loss.sum() / num_items_in_batch
    else:
        loss = weighted_loss.mean()

    return loss
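
In effect, each unmasked token's cross-entropy is scaled by H**alpha, where H is the entropy of the softmax over that position's top-k logits, so confidently-predicted (low-entropy) tokens are down-weighted. A quick smoke-test sketch on random tensors, assuming the module path added in this commit; the SimpleNamespace stands in for a real model output object:

    import torch
    from types import SimpleNamespace

    from axolotl.monkeypatch.loss.eaft import eaft_loss

    torch.manual_seed(0)
    batch, seq_len, vocab = 2, 8, 100
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    labels[:, :3] = -100  # mask a prompt prefix, as a collator would

    outputs = SimpleNamespace(logits=logits)
    print(eaft_loss(outputs, labels, alpha=1.0, k=20))        # mean over unmasked tokens
    print(eaft_loss(outputs, labels, num_items_in_batch=10))  # sum / num_items_in_batch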


@@ -676,6 +676,24 @@ class AxolotlInputConfig(
            "description": "Number of chunks to use for chunked cross entropy loss"
        },
    )
    use_eaft: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable Entropy-Aware Focal Training loss (EAFT)"
        },
    )
    eaft_alpha: float | None = Field(
        default=1.0,
        json_schema_extra={
            "description": "Exponent for entropy weighting in EAFT (default: 1.0)"
        },
    )
    eaft_k: int | None = Field(
        default=20,
        json_schema_extra={
            "description": "Number of top logits for entropy approximation (default: 20)"
        },
    )
    tiled_mlp: bool | None = Field(
        default=None,
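
These schema fields map one-to-one onto axolotl YAML config keys, so enabling the feature presumably looks like use_eaft: true with optional eaft_alpha / eaft_k overrides. A trimmed-down sketch of the same pydantic surface, using a hypothetical EaftConfig in place of the full AxolotlInputConfig:

    from pydantic import BaseModel, Field

    class EaftConfig(BaseModel):
        # Hypothetical mirror of the three fields above, minus schema metadata.
        use_eaft: bool | None = Field(default=None)
        eaft_alpha: float | None = Field(default=1.0)
        eaft_k: int | None = Field(default=20)

    cfg = EaftConfig(use_eaft=True)
    print(cfg.eaft_alpha, cfg.eaft_k)  # 1.0 20 -- matches the trainer-side fallbacks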