Compare commits

2 commits: `fix/doc-ke...fix/cce-li`

| Author | SHA1 | Date |
|---|---|---|
| | 4581d6a8de | |
| | 1a85fab2ca | |
```diff
@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
 {
     "role": "user",
     "content": [
-        {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
+        {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
         {"type": "text", "text": "Describe this image in detail."}
     ]
 },
```
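For context (not part of the diff): the doc fix above switches the image entry from a `url` key to the `image` key that `transformers` multi-modal chat templates read for image content. A minimal sketch of how the corrected message list might be consumed, assuming a recent `transformers`; the model id is a placeholder:

```python
# Sketch (not part of the diff): the corrected message format consumed by a
# transformers multi-modal processor. The model id is a placeholder/assumption.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            # "image" (not "url") is the key the chat template expects here.
            {
                "type": "image",
                "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
```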
```diff
@@ -26,6 +26,7 @@ from transformers.utils import (
 )
 
 _PATCH_OPTS: PatchOptions | None = None
+RESET_LM_HEAD = True
 
 
 @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
```
```diff
@@ -308,7 +309,16 @@ def cce_forward_multimodal(
 
     if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
         assert labels is not None
-        # TODO: check if need to handle attention_mask
+
+        # reset lm head gradient on first pass.
+        # linear model has some lm_head weight issue
+        # see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
+        global RESET_LM_HEAD  # pylint: disable=global-statement
+        if RESET_LM_HEAD:
+            RESET_LM_HEAD = False
+            self.language_model.lm_head.weight.requires_grad_(False)  # Detach
+            self.language_model.lm_head.weight.requires_grad_(True)  # Reattach
+
         loss = apply_lce(
             hidden_states,
             self.language_model.lm_head.weight,
```
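An aside on the `Detach`/`Reattach` lines added above: flipping `requires_grad_` off and back on is a one-shot reset of the autograd state of the `lm_head` weight, guarded by the module-level `RESET_LM_HEAD` flag introduced in the earlier hunk (the linked axolotl PR #2505 describes the underlying lm_head weight issue). A standalone sketch of the pattern, with a toy `nn.Linear` standing in for `self.language_model.lm_head`:

```python
# Standalone sketch of the one-shot detach/reattach pattern from the hunk above.
# A toy nn.Linear stands in for self.language_model.lm_head; names are illustrative.
import torch
import torch.nn as nn

RESET_LM_HEAD = True  # module-level flag, consumed on the first forward pass

lm_head = nn.Linear(16, 32, bias=False)

def forward_once(hidden: torch.Tensor) -> torch.Tensor:
    global RESET_LM_HEAD
    if RESET_LM_HEAD:
        RESET_LM_HEAD = False
        lm_head.weight.requires_grad_(False)  # Detach
        lm_head.weight.requires_grad_(True)   # Reattach
    return lm_head(hidden)

out = forward_once(torch.randn(4, 16))
out.sum().backward()
assert lm_head.weight.grad is not None  # gradient flows normally afterwards
```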
```diff
@@ -373,11 +383,7 @@ def patch_llama4_text(
 
         return maybe_model
 
-    setattr(
-        modeling_llama4.Llama4ForCausalLM,
-        "forward",
-        cce_forward,
-    )
+    modeling_llama4.Llama4ForCausalLM.forward = cce_forward
     return None
 
 
```
```diff
@@ -403,12 +409,8 @@ def patch_llama4(
         )
         return maybe_model
 
-    setattr(
-        modeling_llama4.Llama4ForConditionalGeneration,
-        "forward",
-        cce_forward_multimodal,
-    )
+    modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
 
     # patch the causal language model
-    setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
+    modeling_llama4.Llama4ForCausalLM.forward = cce_forward
     return None
```
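A side note on the patching-style change in the two hunks above (not part of the diff): `setattr(cls, "forward", fn)` and `cls.forward = fn` monkey-patch the class identically; the direct assignment is simply shorter and easier for static analysis tools to follow. A toy sketch of the equivalence:

```python
# Toy sketch (not the library code): both patching styles bind the same
# function as the class's method; behavior is identical.
class Model:
    def forward(self, x):
        return x

def cce_forward(self, x):
    return x * 2

setattr(Model, "forward", cce_forward)   # style removed in the diff
Model.forward = cce_forward              # style added; same effect

assert Model().forward(3) == 6
```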
```diff
@@ -29,7 +29,6 @@ liger_fused_linear_cross_entropy: true
 - granite
 - jamba
 - llama
-- llama4 (partial support, no support for FLCE yet)
 - mistral
 - mixtral
 - mllama
```