Compare commits
2 Commits
transforme
...
fix/cce-li
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4581d6a8de | ||
|
|
1a85fab2ca |
@@ -68,7 +68,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=90 * 60,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
|
||||
@@ -12,7 +12,7 @@ liger-kernel==0.5.6
|
||||
packaging==23.2
|
||||
|
||||
peft==0.15.1
|
||||
transformers==4.51.1
|
||||
transformers==4.51.0
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers.utils import (
|
||||
)
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
RESET_LM_HEAD = True
|
||||
|
||||
|
||||
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
|
||||
@@ -308,7 +309,16 @@ def cce_forward_multimodal(
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
# TODO: check whether attention_mask needs to be handled
|
||||
|
||||
# reset lm head gradient on first pass.
|
||||
# the linear model has an lm_head weight issue
|
||||
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
|
||||
global RESET_LM_HEAD # pylint: disable=global-statement
|
||||
if RESET_LM_HEAD:
|
||||
RESET_LM_HEAD = False
|
||||
self.language_model.lm_head.weight.requires_grad_(False) # Detach
|
||||
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
|
||||
|
||||
loss = apply_lce(
|
||||
hidden_states,
|
||||
self.language_model.lm_head.weight,
|
||||
@@ -373,11 +383,7 @@ def patch_llama4_text(
|
||||
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForCausalLM,
|
||||
"forward",
|
||||
cce_forward,
|
||||
)
|
||||
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
@@ -403,12 +409,8 @@ def patch_llama4(
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
setattr(
|
||||
modeling_llama4.Llama4ForConditionalGeneration,
|
||||
"forward",
|
||||
cce_forward_multimodal,
|
||||
)
|
||||
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
|
||||
# patch the causal language model
|
||||
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
||||
modeling_llama4.Llama4ForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user