# Squashed commit history for this patch:
# * fix: update chat_template
# * fix: handle gemma3 showing a lot of no content for turn 0
# * fix: remove unknown config from examples
# * fix: test
# * fix: temporary disable gemma2 test
# * fix: stop overwriting config.text_config unnecessarily
# * fix: handling of set cache to the text_config section
# * feat: add liger gemma support and bump liger to 0.5.5
# * fix: add double use_cache setting
# * fix: add support for final_logit_softcap in CCE for gemma2/3
# * fix: set use_cache before model load
# * feat: add missing layernorm override
# * fix: handle gemma3 rmsnorm
# * fix: use wrapper to pass dim as hidden_size
# * fix: change dim to positional
# * fix: patch with wrong mlp
# * chore: refactor use_cache handling
# * fix import issues
# * fix tests.e2e.utils import
# Co-authored-by: Wing Lian <wing@axolotl.ai>
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Monkeypatch for apply_lce to add softcap."""

import torch

from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions

def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Monkey patch for apply_lce to support the ``softcap`` kwarg.

    Computes linear cross-entropy via ``cut_cross_entropy`` while forwarding
    ``softcap`` (final-logit softcapping, used by Gemma 2/3) and honoring the
    trainer-supplied ``num_items_in_batch`` for gradient-accumulation-correct
    loss normalization.

    Args:
        e: Input embeddings / hidden states fed to the classifier projection.
        c: Classifier (lm_head) weight tensor.
        labels: Target token ids; moved onto ``e``'s device before the call.
        opts: CCE patch options; ``opts.to_kwargs()`` supplies e.g. ``reduction``.
        bias: Optional classifier bias, forwarded unchanged.
        softcap: Optional final-logit softcap value, forwarded unchanged.
        **loss_kwargs: Extra loss kwargs; only ``num_items_in_batch`` is read.

    Returns:
        The (possibly manually normalized) cross-entropy loss tensor.
    """
    num_items_in_batch = loss_kwargs.get("num_items_in_batch")
    cce_kwargs = opts.to_kwargs()
    # When the trainer provides a token count and mean reduction is requested,
    # switch to sum reduction and divide manually below — this keeps the loss
    # scale correct under gradient accumulation. Otherwise ignore the count.
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,  # shift labels for causal LM next-token prediction
        softcap=softcap,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss
|