From 4581d6a8de336188df562b62cf32244ccdb2fbe8 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 9 Apr 2025 13:45:29 +0700
Subject: [PATCH] fix: accidentally reassigning tensor to weight

---
 .../integrations/cut_cross_entropy/monkeypatch/llama4.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
index 70fcf295d..f9b77c1db 100644
--- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
+++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
@@ -310,13 +310,14 @@ def cce_forward_multimodal(
 
     if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
         assert labels is not None
-        # reset lm head if not already done. linear model has some lm_head weight issue
+        # reset lm head gradient on first pass.
+        # linear model has some lm_head weight issue
+        # see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
         global RESET_LM_HEAD  # pylint: disable=global-statement
         if RESET_LM_HEAD:
             RESET_LM_HEAD = False
-            self.language_model.lm_head.weight = (
-                self.language_model.lm_head.weight.detach().requires_grad_(True)
-            )
+            self.language_model.lm_head.weight.requires_grad_(False)  # Detach
+            self.language_model.lm_head.weight.requires_grad_(True)  # Reattach
 
         loss = apply_lce(
             hidden_states,
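
A minimal sketch of the distinction this fix relies on, using a standalone
nn.Linear as a hypothetical stand-in for language_model.lm_head (the SGD
optimizer and the tensor sizes are illustrative, not from the patch):

import torch
import torch.nn as nn

# Stand-in for language_model.lm_head; sizes are arbitrary.
lm_head = nn.Linear(16, 32, bias=False)
optimizer = torch.optim.SGD(lm_head.parameters(), lr=0.1)

param_before = lm_head.weight

# Fixed pattern from the patch: toggle requires_grad in place. The
# registered nn.Parameter object survives, so the optimizer's reference
# (and any tied weights) still point at the live weight.
lm_head.weight.requires_grad_(False)  # Detach
lm_head.weight.requires_grad_(True)  # Reattach
assert lm_head.weight is param_before  # same Parameter object

# Buggy pattern the patch removes: reassignment swaps the registered
# Parameter for a plain Tensor (detach() returns a Tensor, and
# requires_grad_() returns that same Tensor). nn.Module.__setattr__
# rejects that here with a TypeError, and in cases where the assignment
# does go through (e.g. the attribute is no longer a registered
# Parameter), the optimizer keeps stepping the old, orphaned tensor
# instead of the one used in the forward pass.
try:
    lm_head.weight = lm_head.weight.detach().requires_grad_(True)
except TypeError as err:
    print(f"reassignment rejected: {err}")

Keeping the identical Parameter object is what makes the in-place toggle
safe alongside existing optimizer state and weight tying; the reassignment
is the failure mode the subject line ("accidentally reassigning tensor to
weight") points at.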