Files
axolotl/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
2025-04-28 12:18:46 -04:00

127 lines
4.1 KiB
Python

# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Cut Cross Entropy patcher"""
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
patch_cohere,
patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma2,
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
patch_llama,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral,
patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
patch_qwen2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
patch_qwen2_5_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
patch_qwen2_moe,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
patch_qwen2_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
patch_qwen3_moe,
)
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
"llama4": patch_llama4,
"llama4_text": patch_llama4_text,
"mllama": patch_mllama,
"phi3": patch_phi3,
"gemma": patch_gemma,
"gemma2": patch_gemma2,
"gemma3": patch_gemma3,
"gemma3_text": patch_gemma3_text,
"mistral": patch_mistral,
"mistral3": patch_mistral3,
"qwen2": patch_qwen2,
"qwen2_moe": patch_qwen2_moe,
"qwen2_vl": patch_qwen2_vl,
"qwen2_5_vl": patch_qwen2_5_vl,
"qwen3": patch_qwen3,
"qwen3_moe": patch_qwen3_moe,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}
def cce_patch(
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
reduction: str = "mean",
filter_eps: float | str | None = "auto",
accum_e_fp32: bool = False,
accum_c_fp32: bool = False,
filter_e_grad: bool = True,
filter_c_grad: bool = True,
train_only: bool = False,
) -> TransformersModelT | None:
if isinstance(impl, LinearCrossEntropyImpl):
impl = impl.name.lower()
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
raise ValueError(f"Unknown {impl=}")
if isinstance(model_type_or_model, transformers.PreTrainedModel):
if hasattr(model_type_or_model, "config"):
model_type = getattr(
getattr(model_type_or_model, "config", None), "model_type", None
)
else:
raise ValueError(
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
)
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
model_type = model_type_or_model.model_type
else:
model_type = model_type_or_model
patch_options = PatchOptions(
impl=impl,
reduction=reduction,
filter_eps=filter_eps,
accum_e_fp32=accum_e_fp32,
accum_c_fp32=accum_c_fp32,
filter_e_grad=filter_e_grad,
filter_c_grad=filter_c_grad,
train_only=train_only,
)
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
model_type_or_model, patch_options
)
raise RuntimeError(f"Unknown model type {model_type}")