Multimodal Vision Llama - rudimentary support (#1940)

---------

Co-authored-by: Sunny <sunny@Sunnys-MacBook-Air.local>
Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
Wing Lian
2024-10-02 21:02:48 -04:00
committed by GitHub
parent 844331005c
commit e1915f5625
24 changed files with 799 additions and 119 deletions

View File

@@ -121,15 +121,36 @@ def normalize_config(cfg):
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
"pixtral",
]
)
or cfg.is_multimodal
)
if cfg.is_multimodal:
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
)
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())