|
|
|
|
@@ -26,7 +26,6 @@ from transformers.utils import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_PATCH_OPTS: PatchOptions | None = None
|
|
|
|
|
RESET_LM_HEAD = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
|
|
|
|
@@ -309,16 +308,7 @@ def cce_forward_multimodal(
|
|
|
|
|
|
|
|
|
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
|
|
|
assert labels is not None
|
|
|
|
|
|
|
|
|
|
# reset lm head gradient on first pass.
|
|
|
|
|
# linear model has some lm_head weight issue
|
|
|
|
|
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
|
|
|
|
|
global RESET_LM_HEAD # pylint: disable=global-statement
|
|
|
|
|
if RESET_LM_HEAD:
|
|
|
|
|
RESET_LM_HEAD = False
|
|
|
|
|
self.language_model.lm_head.weight.requires_grad_(False) # Detach
|
|
|
|
|
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
|
|
|
|
|
|
|
|
|
|
# TODO: check if need to handle attention_mask
|
|
|
|
|
loss = apply_lce(
|
|
|
|
|
hidden_states,
|
|
|
|
|
self.language_model.lm_head.weight,
|
|
|
|
|
@@ -383,7 +373,11 @@ def patch_llama4_text(
|
|
|
|
|
|
|
|
|
|
return maybe_model
|
|
|
|
|
|
|
|
|
|
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
|
|
|
|
setattr(
|
|
|
|
|
modeling_llama4.Llama4ForCausalLM,
|
|
|
|
|
"forward",
|
|
|
|
|
cce_forward,
|
|
|
|
|
)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -409,8 +403,12 @@ def patch_llama4(
|
|
|
|
|
)
|
|
|
|
|
return maybe_model
|
|
|
|
|
|
|
|
|
|
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
|
|
|
|
|
setattr(
|
|
|
|
|
modeling_llama4.Llama4ForConditionalGeneration,
|
|
|
|
|
"forward",
|
|
|
|
|
cce_forward_multimodal,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# patch the causal language model
|
|
|
|
|
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
|
|
|
|
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
|
|
|
|
return None
|
|
|
|
|
|