Multimodal Vision Llama - rudimentary support (#1940)
--------- Co-authored-by: Sunny <sunny@Sunnys-MacBook-Air.local> Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
@@ -28,12 +28,17 @@ from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
LlavaForConditionalGeneration,
|
||||
MllamaForConditionalGeneration,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
@@ -80,6 +85,9 @@ def get_module_class_from_name(module, name):
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
if cfg.is_multimodal:
|
||||
model_config = model_config.text_config
|
||||
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
and model_config.quantization_config
|
||||
@@ -299,11 +307,31 @@ def load_tokenizer(cfg):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
||||
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
|
||||
|
||||
processor_cls = AutoProcessor
|
||||
if cfg.processor_type:
|
||||
processor_cls = getattr(transformers, cfg.processor_type)
|
||||
|
||||
processor = processor_cls.from_pretrained(
|
||||
cfg.processor_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
tokenizer=tokenizer,
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
*,
|
||||
processor: ProcessorMixin = None, # pylint: disable=unused-argument
|
||||
inference: bool = False,
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
@@ -319,12 +347,23 @@ def load_model(
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(cfg)
|
||||
|
||||
if cfg.is_multimodal:
|
||||
text_model_config = model_config.text_config
|
||||
else:
|
||||
text_model_config = model_config
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
|
||||
if cfg.gradient_checkpointing == "unsloth":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
||||
|
||||
if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.attention.mllama import patch_mllama
|
||||
|
||||
patch_mllama()
|
||||
|
||||
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||
@@ -461,6 +500,19 @@ def load_model(
|
||||
max_memory = cfg.max_memory
|
||||
device_map = cfg.device_map
|
||||
|
||||
AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
if cfg.is_multimodal:
|
||||
if model_config.model_type == "llava":
|
||||
AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
LlavaForConditionalGeneration
|
||||
)
|
||||
elif model_config.model_type == "mllama":
|
||||
AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
MllamaForConditionalGeneration
|
||||
)
|
||||
else:
|
||||
AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||
|
||||
if cfg.gpu_memory_limit:
|
||||
gpu_memory_limit = (
|
||||
str(cfg.gpu_memory_limit) + "GiB"
|
||||
@@ -478,7 +530,7 @@ def load_model(
|
||||
from accelerate import infer_auto_device_map
|
||||
|
||||
with init_empty_weights():
|
||||
model_canvas = AutoModelForCausalLM.from_config(
|
||||
model_canvas = AutoModelLoader.from_config(
|
||||
model_config, trust_remote_code=cfg.trust_remote_code or False
|
||||
)
|
||||
model_canvas.tie_weights()
|
||||
@@ -633,6 +685,8 @@ def load_model(
|
||||
quantization_config = (
|
||||
quantization_config or model_kwargs["quantization_config"]
|
||||
)
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = load_sharded_model_quant(
|
||||
base_model,
|
||||
model_config,
|
||||
@@ -651,7 +705,9 @@ def load_model(
|
||||
if "device_map" in model_kwargs:
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = AutoModelLoader.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
**model_kwargs,
|
||||
@@ -690,13 +746,17 @@ def load_model(
|
||||
and not cfg.trust_remote_code
|
||||
):
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = AutoModelLoader.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
@@ -707,21 +767,23 @@ def load_model(
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(model_config, "max_seq_len")
|
||||
and model_config.max_seq_len
|
||||
hasattr(text_model_config, "max_seq_len")
|
||||
and text_model_config.max_seq_len
|
||||
and cfg.sequence_len > model_config.max_seq_len
|
||||
):
|
||||
model_config.max_seq_len = cfg.sequence_len
|
||||
text_model_config.max_seq_len = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(model_config, "max_sequence_length")
|
||||
and model_config.max_sequence_length
|
||||
and cfg.sequence_len > model_config.max_sequence_length
|
||||
hasattr(text_model_config, "max_sequence_length")
|
||||
and text_model_config.max_sequence_length
|
||||
and cfg.sequence_len > text_model_config.max_sequence_length
|
||||
):
|
||||
model_config.max_sequence_length = cfg.sequence_len
|
||||
text_model_config.max_sequence_length = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = AutoModelLoader.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -734,7 +796,9 @@ def load_model(
|
||||
if "device_map" in model_kwargs:
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
if cfg.is_multimodal:
|
||||
model_config.text_config = text_model_config
|
||||
model = AutoModelLoader.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -1016,12 +1080,17 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
lora_target_modules = cfg.lora_target_modules or []
|
||||
|
||||
if cfg.lora_target_linear:
|
||||
linear_names = find_all_linear_names(model)
|
||||
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
|
||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||
lora_target_modules_as_list = (
|
||||
lora_target_modules
|
||||
if isinstance(lora_target_modules, list)
|
||||
else [lora_target_modules]
|
||||
)
|
||||
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
|
||||
|
||||
lora_config_kwargs = {}
|
||||
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
@@ -1040,6 +1109,7 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
layers_to_transform=cfg.peft_layers_to_transform,
|
||||
layers_pattern=cfg.peft_layers_pattern,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
|
||||
Reference in New Issue
Block a user