diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py index 70fcf295d..f9b77c1db 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py @@ -310,13 +310,14 @@ def cce_forward_multimodal( if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): assert labels is not None - # reset lm head if not already done. linear model has some lm_head weight issue + # reset lm head gradient on first pass. + # linear model has some lm_head weight issue + # see https://github.com/axolotl-ai-cloud/axolotl/pull/2505 global RESET_LM_HEAD # pylint: disable=global-statement if RESET_LM_HEAD: RESET_LM_HEAD = False - self.language_model.lm_head.weight = ( - self.language_model.lm_head.weight.detach().requires_grad_(True) - ) + self.language_model.lm_head.weight.requires_grad_(False) # Detach + self.language_model.lm_head.weight.requires_grad_(True) # Reattach loss = apply_lce( hidden_states,