Compare commits

..

2 Commits

Author SHA1 Message Date
NanoCode012
4581d6a8de fix: accidentally reassigning tensor to weight 2025-04-09 13:45:29 +07:00
NanoCode012
1a85fab2ca fix: lm_head is a view or related view modified 2025-04-08 17:32:28 +07:00
3 changed files with 15 additions and 14 deletions

View File

@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
{
"role": "user",
"content": [
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},

View File

@@ -26,6 +26,7 @@ from transformers.utils import (
)
_PATCH_OPTS: PatchOptions | None = None
RESET_LM_HEAD = True
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@@ -308,7 +309,16 @@ def cce_forward_multimodal(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# TODO: check if need to handle attention_mask
# reset lm head gradient on first pass.
# linear model has some lm_head weight issue
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
global RESET_LM_HEAD # pylint: disable=global-statement
if RESET_LM_HEAD:
RESET_LM_HEAD = False
self.language_model.lm_head.weight.requires_grad_(False) # Detach
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
@@ -373,11 +383,7 @@ def patch_llama4_text(
return maybe_model
setattr(
modeling_llama4.Llama4ForCausalLM,
"forward",
cce_forward,
)
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
return None
@@ -403,12 +409,8 @@ def patch_llama4(
)
return maybe_model
setattr(
modeling_llama4.Llama4ForConditionalGeneration,
"forward",
cce_forward_multimodal,
)
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal language model
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
return None

View File

@@ -29,7 +29,6 @@ liger_fused_linear_cross_entropy: true
- granite
- jamba
- llama
- llama4 (partial support, no support for FLCE yet)
- mistral
- mixtral
- mllama