diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
index f08663f99..70fcf295d 100644
--- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
+++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
@@ -26,6 +26,7 @@ from transformers.utils import (
 )
 
 _PATCH_OPTS: PatchOptions | None = None
+RESET_LM_HEAD = True
 
 
 @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@@ -308,7 +309,15 @@ def cce_forward_multimodal(
 
     if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
         assert labels is not None
-        # TODO: check if need to handle attention_mask
+
+        # reset lm head if not already done. linear model has some lm_head weight issue
+        global RESET_LM_HEAD  # pylint: disable=global-statement
+        if RESET_LM_HEAD:
+            RESET_LM_HEAD = False
+            self.language_model.lm_head.weight = (
+                self.language_model.lm_head.weight.detach().requires_grad_(True)
+            )
+
         loss = apply_lce(
             hidden_states,
             self.language_model.lm_head.weight,
@@ -373,11 +382,7 @@ def patch_llama4_text(
 
         return maybe_model
 
-    setattr(
-        modeling_llama4.Llama4ForCausalLM,
-        "forward",
-        cce_forward,
-    )
+    modeling_llama4.Llama4ForCausalLM.forward = cce_forward
     return None
 
 
@@ -403,12 +408,8 @@ def patch_llama4(
         )
         return maybe_model
 
-    setattr(
-        modeling_llama4.Llama4ForConditionalGeneration,
-        "forward",
-        cce_forward_multimodal,
-    )
+    modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
 
     # patch the causal language model
-    setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
+    modeling_llama4.Llama4ForCausalLM.forward = cce_forward
     return None
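
Note on the lm_head reset added in `cce_forward_multimodal`: detaching the weight and re-enabling gradients rebinds it as a fresh leaf tensor, so autograd can accumulate gradients on it directly when `apply_lce` consumes it. The sketch below is a minimal, standalone illustration of that pattern, not the axolotl integration itself; the tensor names and the graph setup are hypothetical.

```python
import torch

# Hypothetical stand-in for language_model.lm_head.weight: a tensor that is
# already part of an autograd graph (non-leaf), so backward() would never
# populate a .grad for it.
base = torch.randn(8, 4, requires_grad=True)
lm_head_weight = base * 1.0            # non-leaf: has a grad_fn
assert not lm_head_weight.is_leaf

# The reset used in the patch: detach from the old graph, then re-enable
# gradients. The result is an independent leaf tensor with its own .grad slot.
lm_head_weight = lm_head_weight.detach().requires_grad_(True)
assert lm_head_weight.is_leaf and lm_head_weight.grad_fn is None

# Gradients now accumulate on the rebound weight directly.
hidden = torch.randn(2, 4)
logits = hidden @ lm_head_weight.t()   # (2, 4) @ (4, 8) -> (2, 8)
logits.sum().backward()
assert lm_head_weight.grad is not None
```

The module-level `RESET_LM_HEAD` flag in the patch guards this rebinding so it runs only once, on the first patched forward call.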