fix: mllama need attention fixes for fa2

This commit is contained in:
NanoCode012
2025-07-22 12:11:30 +07:00
parent ee19007ba4
commit a0bfdd1777
3 changed files with 42 additions and 0 deletions

View File

@@ -283,6 +283,11 @@ class PatchManager:
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type == "mllama":
from axolotl.monkeypatch.mllama_attn_patch import patch_mllama_attention
patch_mllama_attention()
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:

View File

@@ -0,0 +1,27 @@
"""
Monkeypatch to add missing is_causal attribute to Mllama attention classes
"""
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_mllama_attention():
"""Add is_causal attribute to Mllama attention classes"""
try:
import transformers.models.mllama.modeling_mllama as mllama_modeling
LOG.debug("Patching Attention in mllama due to missing attributes")
if hasattr(mllama_modeling, "MllamaVisionAttention"):
mllama_modeling.MllamaVisionAttention.is_causal = False
if hasattr(mllama_modeling, "MllamaCrossAttention"):
mllama_modeling.MllamaCrossAttention.is_causal = False
if hasattr(mllama_modeling, "MllamaTextAttention"):
mllama_modeling.MllamaTextAttention.is_causal = True
except ImportError:
LOG.debug("Mllama model not available, skipping is_causal patch")

View File

@@ -92,9 +92,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
if "pixel_values" in batch:
final_batch["pixel_values"] = torch.stack(batch["pixel_values"])
# mllama
if "aspect_ratio_ids" in batch:
final_batch["aspect_ratio_ids"] = torch.stack(batch["aspect_ratio_ids"])
if "aspect_ratio_mask" in batch:
final_batch["aspect_ratio_mask"] = torch.stack(batch["aspect_ratio_mask"])
if "cross_attention_mask" in batch:
final_batch["cross_attention_mask"] = torch.nn.utils.rnn.pad_sequence(
batch["cross_attention_mask"], batch_first=True, padding_value=0
)
# gemma3n
if "input_features" in batch:
final_batch["input_features"] = torch.stack(batch["input_features"])