diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py index 5cdc53b0c..99e17910e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py @@ -128,10 +128,10 @@ def cce_forward( if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): assert labels is not None - # scale weight by logit_scale in-place of logits + # scale hidden_states by logit_scale in-place of logits loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight * self.logit_scale, + hidden_states[:, slice_indices, :] * self.logit_scale, + self.lm_head.weight, labels, _PATCH_OPTS, **kwargs,