Compare commits

..

1 Commit

Author: Wing Lian
SHA1: ebe5abad53
Message: 0.8.1 version
Date: 2025-04-07 20:49:40 -04:00
2 changed files with 13 additions and 15 deletions

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0"
__version__ = "0.8.1"

View File

@@ -26,7 +26,6 @@ from transformers.utils import (
)
_PATCH_OPTS: PatchOptions | None = None
-RESET_LM_HEAD = True
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@@ -309,16 +308,7 @@ def cce_forward_multimodal(
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
-# reset lm head gradient on first pass.
-# linear model has some lm_head weight issue
-# see https://github.com/axolotl-ai-cloud/axolotl/pull/2505
-global RESET_LM_HEAD  # pylint: disable=global-statement
-if RESET_LM_HEAD:
-    RESET_LM_HEAD = False
-    self.language_model.lm_head.weight.requires_grad_(False)  # Detach
-    self.language_model.lm_head.weight.requires_grad_(True)  # Reattach
# TODO: check if need to handle attention_mask
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
@@ -383,7 +373,11 @@ def patch_llama4_text(
return maybe_model
-modeling_llama4.Llama4ForCausalLM.forward = cce_forward
+setattr(
+    modeling_llama4.Llama4ForCausalLM,
+    "forward",
+    cce_forward,
+)
return None
@@ -409,8 +403,12 @@ def patch_llama4(
)
return maybe_model
-modeling_llama4.Llama4ForConditionalGeneration.forward = cce_forward_multimodal
+setattr(
+    modeling_llama4.Llama4ForConditionalGeneration,
+    "forward",
+    cce_forward_multimodal,
+)
# patch the causal language model
-modeling_llama4.Llama4ForCausalLM.forward = cce_forward
+setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
return None
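
Both patching hunks above swap the upstream forward methods for the CCE-aware ones by assigning to the class attribute; setattr(cls, "forward", fn) does the same thing as cls.forward = fn, so the change appears stylistic rather than behavioral. A minimal sketch of the monkey-patching pattern, using illustrative class and function names rather than the real transformers classes:

    class Upstream:
        def forward(self, x):
            return x + 1

    def cce_forward(self, x):
        # replacement forward; it is looked up on the class, so "self" still binds normally
        return x * 2

    # the two patching styles are equivalent
    Upstream.forward = cce_forward
    setattr(Upstream, "forward", cce_forward)

    assert Upstream().forward(3) == 6

Because the assignment is made on the class, every existing and future instance picks up the patched forward; code that wants to undo the patch has to keep a reference to the original function first.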
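
The apply_lce call in cce_forward_multimodal is the fused linear-cross-entropy path: the lm_head projection and the cross-entropy loss are computed together instead of materializing the full logits tensor. As a rough, unfused reference for the loss it corresponds to (illustrative names only, ignoring whatever shifting and masking apply_lce handles internally, and assuming standard causal-LM label shifting):

    import torch
    import torch.nn.functional as F

    def unfused_linear_cross_entropy(
        hidden_states: torch.Tensor,   # (batch, seq, hidden)
        lm_head_weight: torch.Tensor,  # (vocab, hidden)
        labels: torch.Tensor,          # (batch, seq)
    ) -> torch.Tensor:
        # project to vocabulary logits; this is the memory-heavy step the fused kernel avoids
        logits = hidden_states @ lm_head_weight.T
        # shift so position t predicts token t+1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,  # common ignore value; an assumption, not taken from apply_lce
        )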