fix: mllama needs attention fixes for fa2
This commit is contained in:
@@ -283,6 +283,11 @@ class PatchManager:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
if self.model_config.model_type == "mllama":
|
||||
from axolotl.monkeypatch.mllama_attn_patch import patch_mllama_attention
|
||||
|
||||
patch_mllama_attention()
|
||||
|
||||
def _patch_loss_llama(self):
|
||||
"""Patch loss functions and other optimizations for LLaMA models."""
|
||||
if not self.cfg.is_llama_derived_model:
|
||||
|
||||
27
src/axolotl/monkeypatch/mllama_attn_patch.py
Normal file
27
src/axolotl/monkeypatch/mllama_attn_patch.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Monkeypatch to add missing is_causal attribute to Mllama attention classes
|
||||
"""
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patch_mllama_attention():
    """Set the class-level ``is_causal`` flag on Mllama attention classes.

    Each class is looked up by name on
    ``transformers.models.mllama.modeling_mllama`` and is patched only when
    the module actually defines it; names that are absent are skipped
    silently. The vision and cross-attention classes are marked non-causal
    and the text attention class is marked causal.

    If the Mllama modeling module cannot be imported, nothing is patched and
    a debug message is logged instead of raising.
    """
    # (class name, value assigned to that class's `is_causal` attribute)
    causal_flags = (
        ("MllamaVisionAttention", False),
        ("MllamaCrossAttention", False),
        ("MllamaTextAttention", True),
    )

    try:
        import transformers.models.mllama.modeling_mllama as mllama_modeling

        LOG.debug("Patching Attention in mllama due to missing attributes")

        for class_name, flag in causal_flags:
            # Tolerate transformers versions that do not expose this class.
            if not hasattr(mllama_modeling, class_name):
                continue
            attention_cls = getattr(mllama_modeling, class_name)
            attention_cls.is_causal = flag
    except ImportError:
        LOG.debug("Mllama model not available, skipping is_causal patch")
|
||||
@@ -92,9 +92,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
if "pixel_values" in batch:
|
||||
final_batch["pixel_values"] = torch.stack(batch["pixel_values"])
|
||||
|
||||
# mllama
|
||||
if "aspect_ratio_ids" in batch:
|
||||
final_batch["aspect_ratio_ids"] = torch.stack(batch["aspect_ratio_ids"])
|
||||
|
||||
if "aspect_ratio_mask" in batch:
|
||||
final_batch["aspect_ratio_mask"] = torch.stack(batch["aspect_ratio_mask"])
|
||||
|
||||
if "cross_attention_mask" in batch:
|
||||
final_batch["cross_attention_mask"] = torch.nn.utils.rnn.pad_sequence(
|
||||
batch["cross_attention_mask"], batch_first=True, padding_value=0
|
||||
)
|
||||
|
||||
# gemma3n
|
||||
if "input_features" in batch:
|
||||
final_batch["input_features"] = torch.stack(batch["input_features"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user