This commit is contained in:
bursteratom
2024-12-06 14:37:32 -05:00
parent a3a4d22709
commit eab1638686

View File

@@ -30,6 +30,7 @@ from transformers import ( # noqa: F401
AddedToken, AddedToken,
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq, AutoModelForVision2Seq,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
@@ -553,6 +554,10 @@ class ModelLoader:
self.AutoModelLoader = ( # pylint: disable=invalid-name self.AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration MllamaForConditionalGeneration
) )
elif self.model_config.model_type == "qwen2_vl":
self.AutoModelLoader = ( # pylint: disable=invalid-name
AutoModelForImageTextToText
)
else: else:
self.AutoModelLoader = ( self.AutoModelLoader = (
AutoModelForVision2Seq # pylint: disable=invalid-name AutoModelForVision2Seq # pylint: disable=invalid-name