diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3da5cc0dd..613bf9b21 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -30,6 +30,7 @@ from transformers import ( # noqa: F401 AddedToken, AutoConfig, AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, @@ -553,6 +554,10 @@ class ModelLoader: self.AutoModelLoader = ( # pylint: disable=invalid-name MllamaForConditionalGeneration ) + elif self.model_config.model_type == "qwen2_vl": + self.AutoModelLoader = ( # pylint: disable=invalid-name + AutoModelForImageTextToText + ) else: self.AutoModelLoader = ( AutoModelForVision2Seq # pylint: disable=invalid-name