fix: minor patches for multimodal (#2441)
* fix: update chat_template * fix: handle gemma3 showing a lot of no content for turn 0 * fix: remove unknown config from examples * fix: test * fix: temporary disable gemma2 test * fix: stop overwriting config.text_config unnecessarily * fix: handling of set cache to the text_config section * feat: add liger gemma support and bump liger to 0.5.5 * fix: add double use_cache setting * fix: add support for final_logit_softcap in CCE for gemma2/3 * fix: set use_cache before model load * feat: add missing layernorm override * fix: handle gemma3 rmsnorm * fix: use wrapper to pass dim as hidden_size * fix: change dim to positional * fix: patch with wrong mlp * chore: refactor use_cache handling * fix import issues * fix tests.e2e.utils import --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -25,8 +25,8 @@ import torch
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.distributed import zero_only
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
@@ -15,7 +15,6 @@ import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
@@ -33,6 +32,8 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@@ -134,25 +135,17 @@ def cce_forward(
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logger.warning_once(
|
||||
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
|
||||
)
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logger.warning_once(
|
||||
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
|
||||
)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
@@ -353,6 +346,7 @@ def cce_forward_multimodal(
|
||||
self.language_model.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**lm_kwargs,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
"""Monkeypatch for apply_lce to add softcap."""
|
||||
|
||||
import torch
|
||||
from cut_cross_entropy import linear_cross_entropy
|
||||
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||
|
||||
|
||||
def apply_lce(
|
||||
e: torch.Tensor,
|
||||
c: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
opts: PatchOptions,
|
||||
bias: torch.Tensor | None = None,
|
||||
softcap: float | None = None,
|
||||
**loss_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||
cce_kwargs = opts.to_kwargs()
|
||||
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||
cce_kwargs["reduction"] = "sum"
|
||||
else:
|
||||
num_items_in_batch = None
|
||||
|
||||
loss = linear_cross_entropy(
|
||||
e,
|
||||
c,
|
||||
labels.to(e.device),
|
||||
bias=bias,
|
||||
shift=True,
|
||||
softcap=softcap,
|
||||
**cce_kwargs,
|
||||
)
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return loss
|
||||
@@ -20,6 +20,26 @@ liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
- deepseek_v2
|
||||
- gemma
|
||||
- gemma2
|
||||
- gemma3 (partial support, no support for FLCE yet)
|
||||
- granite
|
||||
- jamba
|
||||
- llama
|
||||
- mistral
|
||||
- mixtral
|
||||
- mllama
|
||||
- mllama_text_model
|
||||
- olmo2
|
||||
- paligemma
|
||||
- phi3
|
||||
- qwen2
|
||||
- qwen2_5_vl
|
||||
- qwen2_vl
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
|
||||
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
|
||||
def pre_model_load(self, cfg):
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||
raise ValueError(
|
||||
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||
)
|
||||
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
@@ -104,15 +114,51 @@ class LigerPlugin(BasePlugin):
|
||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_glu_activation:
|
||||
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||
# nn.CrossEntropyLoss in the forward method.
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
|
||||
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||
from transformers.models.gemma3 import modeling_gemma3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
|
||||
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||
|
||||
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||
_liger_rms_norm_wrapper,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||
if cfg.liger_layer_norm:
|
||||
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
raise NotImplementedError(
|
||||
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||
)
|
||||
elif cfg.model_config_type in ["deepseek_v3"]:
|
||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||
|
||||
@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if turn_idx >= len(turns):
|
||||
raise ValueError(f"Turn index {turn_idx} out of range")
|
||||
|
||||
# mistral does not output message if it contains only system message
|
||||
# mistral/gemma3 does not output message if it contains only system message
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turns[0].get("role") == "system"
|
||||
and "mistral" in self.tokenizer.name_or_path.lower()
|
||||
and (
|
||||
"mistral" in self.tokenizer.name_or_path.lower()
|
||||
# gemma3 uses gemma tokenizer
|
||||
or "gemma" in self.tokenizer.name_or_path.lower()
|
||||
)
|
||||
):
|
||||
return -1, -1
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import math
|
||||
import os
|
||||
import types
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
@@ -25,7 +25,7 @@ from peft import (
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers import ( # noqa: F401
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -39,6 +39,7 @@ from transformers import ( # noqa: F401
|
||||
LlavaForConditionalGeneration,
|
||||
Mistral3ForConditionalGeneration,
|
||||
MllamaForConditionalGeneration,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
@@ -107,14 +108,21 @@ def get_module_class_from_name(module, name):
|
||||
return None
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
# Set use_cache to False
|
||||
if hasattr(model_config, "use_cache"):
|
||||
model_config.use_cache = False
|
||||
|
||||
if cfg.is_multimodal:
|
||||
if hasattr(model_config, "text_config"):
|
||||
model_config = model_config.text_config
|
||||
model_config.use_cache = False
|
||||
elif hasattr(model_config, "get_text_config"):
|
||||
model_config = model_config.get_text_config()
|
||||
model_config.use_cache = False
|
||||
# For multimodal configs, use_cache is set in the text_config
|
||||
if hasattr(model_config, "get_text_config"):
|
||||
text_config = model_config.get_text_config()
|
||||
if hasattr(text_config, "use_cache"):
|
||||
text_config.use_cache = False
|
||||
else:
|
||||
raise ValueError(
|
||||
"No text config found for multimodal model. Please raise an Issue with model details."
|
||||
)
|
||||
|
||||
# check if image_size is not set and load image size from model config if available
|
||||
if (
|
||||
@@ -523,14 +531,6 @@ class ModelLoader:
|
||||
|
||||
# init model config
|
||||
self.model_config = load_model_config(cfg)
|
||||
if cfg.is_multimodal:
|
||||
if hasattr(self.model_config, "text_config"):
|
||||
self.text_model_config = self.model_config.text_config
|
||||
else:
|
||||
# for qwen2_vl
|
||||
self.text_model_config = self.model_config.get_text_config()
|
||||
else:
|
||||
self.text_model_config = self.model_config
|
||||
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
@@ -947,8 +947,6 @@ class ModelLoader:
|
||||
quantization_config = (
|
||||
quantization_config or self.model_kwargs["quantization_config"]
|
||||
)
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = load_sharded_model_quant(
|
||||
self.base_model,
|
||||
self.model_config,
|
||||
@@ -969,9 +967,6 @@ class ModelLoader:
|
||||
|
||||
_ = _configure_zero3_memory_efficient_loading()
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
|
||||
# Load model with random initialization if specified
|
||||
if self.cfg.random_init_weights:
|
||||
# AutoModel classes support the from_config method
|
||||
@@ -1026,8 +1021,6 @@ class ModelLoader:
|
||||
and self.model_type != "AutoModelForCausalLM"
|
||||
and not self.cfg.trust_remote_code
|
||||
):
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
if self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
@@ -1043,25 +1036,7 @@ class ModelLoader:
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(self.text_model_config, "max_seq_len")
|
||||
and self.text_model_config.max_seq_len
|
||||
and self.cfg.sequence_len > self.text_model_config.max_seq_len
|
||||
):
|
||||
self.text_model_config.max_seq_len = self.cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(self.text_model_config, "max_sequence_length")
|
||||
and self.text_model_config.max_sequence_length
|
||||
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
|
||||
):
|
||||
self.text_model_config.max_sequence_length = self.cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
||||
if self.cfg.gptq:
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
@@ -1080,8 +1055,6 @@ class ModelLoader:
|
||||
|
||||
_ = _configure_zero3_memory_efficient_loading()
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
@@ -1346,8 +1319,6 @@ class ModelLoader:
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
LOG.warning("there are no parameters that require gradient updates")
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
if self.cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
Reference in New Issue
Block a user