Multimodal Vision Llama - rudimentary support (#1940)

---------

Co-authored-by: Sunny <sunny@Sunnys-MacBook-Air.local>
Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
Wing Lian
2024-10-02 21:02:48 -04:00
committed by GitHub
parent 844331005c
commit e1915f5625
24 changed files with 799 additions and 119 deletions

View File

@@ -28,12 +28,17 @@ from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
LlavaForConditionalGeneration,
MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -80,6 +85,9 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
model_config = model_config.text_config
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
@@ -299,11 +307,31 @@ def load_tokenizer(cfg):
return tokenizer
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
return processor
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
*,
processor: ProcessorMixin = None, # pylint: disable=unused-argument
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
@@ -319,12 +347,23 @@ def load_model(
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(cfg)
if cfg.is_multimodal:
text_model_config = model_config.text_config
else:
text_model_config = model_config
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
if cfg.flash_attention:
from axolotl.monkeypatch.attention.mllama import patch_mllama
patch_mllama()
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -461,6 +500,19 @@ def load_model(
max_memory = cfg.max_memory
device_map = cfg.device_map
AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
if cfg.is_multimodal:
if model_config.model_type == "llava":
AutoModelLoader = ( # pylint: disable=invalid-name
LlavaForConditionalGeneration
)
elif model_config.model_type == "mllama":
AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name
if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
@@ -478,7 +530,7 @@ def load_model(
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = AutoModelForCausalLM.from_config(
model_canvas = AutoModelLoader.from_config(
model_config, trust_remote_code=cfg.trust_remote_code or False
)
model_canvas.tie_weights()
@@ -633,6 +685,8 @@ def load_model(
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = load_sharded_model_quant(
base_model,
model_config,
@@ -651,7 +705,9 @@ def load_model(
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained(
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
@@ -690,13 +746,17 @@ def load_model(
and not cfg.trust_remote_code
):
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
@@ -707,21 +767,23 @@ def load_model(
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
hasattr(text_model_config, "max_seq_len")
and text_model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
model_config.max_seq_len = cfg.sequence_len
text_model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
hasattr(text_model_config, "max_sequence_length")
and text_model_config.max_sequence_length
and cfg.sequence_len > text_model_config.max_sequence_length
):
model_config.max_sequence_length = cfg.sequence_len
text_model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -734,7 +796,9 @@ def load_model(
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = AutoModelForCausalLM.from_pretrained(
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False):
from peft import LoraConfig, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
lora_target_modules = cfg.lora_target_modules or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
@@ -1040,6 +1109,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,