diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 48ee78cbc..7e9dd955c 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -283,6 +283,11 @@ class PatchManager:
 
             replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
 
+        if self.model_config.model_type == "mllama":
+            from axolotl.monkeypatch.mllama_attn_patch import patch_mllama_attention
+
+            patch_mllama_attention()
+
     def _patch_loss_llama(self):
         """Patch loss functions and other optimizations for LLaMA models."""
         if not self.cfg.is_llama_derived_model:
diff --git a/src/axolotl/monkeypatch/mllama_attn_patch.py b/src/axolotl/monkeypatch/mllama_attn_patch.py
new file mode 100644
index 000000000..8acf4231d
--- /dev/null
+++ b/src/axolotl/monkeypatch/mllama_attn_patch.py
@@ -0,0 +1,45 @@
+"""
+Monkeypatch to add missing is_causal attribute to Mllama attention classes
+"""
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+# Attention class name -> value of `is_causal` to install on that class.
+# Lists the names this patch originally targeted plus the text attention names
+# `MllamaTextCrossAttention` / `MllamaTextSelfAttention`; any name the
+# installed transformers does not define is skipped rather than treated as fatal.
+# NOTE(review): confirm the class names against the pinned transformers version.
+_IS_CAUSAL_BY_CLASS = {
+    "MllamaVisionAttention": False,
+    "MllamaCrossAttention": False,
+    "MllamaTextCrossAttention": False,
+    "MllamaTextAttention": True,
+    "MllamaTextSelfAttention": True,
+}
+
+
+def patch_mllama_attention():
+    """
+    Set the class-level `is_causal` attribute on Mllama attention classes.
+
+    Classes that the installed transformers version does not define are
+    skipped. If the mllama modeling module cannot be imported at all, nothing
+    is patched and a debug message is logged.
+    """
+    try:
+        import transformers.models.mllama.modeling_mllama as mllama_modeling
+    except ImportError:
+        LOG.debug("Mllama model not available, skipping is_causal patch")
+        return
+
+    LOG.debug("Patching Attention in mllama due to missing attributes")
+
+    for class_name, is_causal in _IS_CAUSAL_BY_CLASS.items():
+        attention_cls = getattr(mllama_modeling, class_name, None)
+        if attention_cls is None:
+            # Keep name mismatches visible instead of skipping them silently.
+            LOG.debug("%s not found in modeling_mllama, skipping", class_name)
+            continue
+        attention_cls.is_causal = is_causal
diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py
index c5f6a62e3..1003706cc 100644
--- a/src/axolotl/utils/collators/mm_chat.py
+++ b/src/axolotl/utils/collators/mm_chat.py
@@ -92,9 +92,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         if "pixel_values" in batch:
             final_batch["pixel_values"] = torch.stack(batch["pixel_values"])
 
+        # mllama
         if "aspect_ratio_ids" in batch:
             final_batch["aspect_ratio_ids"] = torch.stack(batch["aspect_ratio_ids"])
 
+        if "aspect_ratio_mask" in batch:
+            final_batch["aspect_ratio_mask"] = torch.stack(batch["aspect_ratio_mask"])
+
+        if "cross_attention_mask" in batch:
+            final_batch["cross_attention_mask"] = torch.nn.utils.rnn.pad_sequence(
+                batch["cross_attention_mask"], batch_first=True, padding_value=0
+            )
+
+        # gemma3n
         if "input_features" in batch:
             final_batch["input_features"] = torch.stack(batch["input_features"])
 