Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
37a66e6866 multigpu longer timeout 2025-04-09 01:54:35 -04:00
Wing Lian
9f69597a5f upgrade transformers to 4.51.1 2025-04-09 00:20:50 -04:00
3 changed files with 14 additions and 16 deletions

View File

@@ -68,7 +68,7 @@ def run_cmd(cmd: str, run_folder: str):
@app.function( @app.function(
image=cicd_image, image=cicd_image,
gpu=GPU_CONFIG, gpu=GPU_CONFIG,
timeout=60 * 60, timeout=90 * 60,
cpu=8.0, cpu=8.0,
memory=131072 * N_GPUS, memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG, volumes=VOLUME_CONFIG,

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.6
packaging==23.2 packaging==23.2
peft==0.15.1 peft==0.15.1
transformers==4.51.0 transformers==4.51.1
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.6.0 accelerate==1.6.0
datasets==3.5.0 datasets==3.5.0

View File

@@ -26,7 +26,6 @@ from transformers.utils import (
) )
_PATCH_OPTS: PatchOptions | None = None _PATCH_OPTS: PatchOptions | None = None
RESET_LM_HEAD = True
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@@ -309,16 +308,7 @@ def cce_forward_multimodal(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None assert labels is not None
# TODO: check if need to handle attention_mask
# reset lm head gradient on first pass.
# linear model has some lm_head weight issue
# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
global RESET_LM_HEAD # pylint: disable=global-statement
if RESET_LM_HEAD:
RESET_LM_HEAD = False
self.language_model.lm_head.weight.requires_grad_(False) # Detach
self.language_model.lm_head.weight.requires_grad_(True) # Reattach
loss = apply_lce( loss = apply_lce(
hidden_states, hidden_states,
self.language_model.lm_head.weight, self.language_model.lm_head.weight,
@@ -383,7 +373,11 @@ def patch_llama4_text(
return maybe_model return maybe_model
modeling_llama4.Llama4ForCausalLM.forward = cce_forward setattr(
modeling_llama4.Llama4ForCausalLM,
"forward",
cce_forward,
)
return None return None
@@ -409,8 +403,12 @@ def patch_llama4(
) )
return maybe_model return maybe_model
modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal setattr(
modeling_llama4.Llama4ForConditionalGeneration,
"forward",
cce_forward_multimodal,
)
# patch the causal language model # patch the causal language model
modeling_llama4.Llama4ForCausalLM.forward = cce_forward setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
return None return None