diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 00a1a0c67..80cbb0f82 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -40,6 +40,14 @@ try:
 except ImportError:
     pass
 
+try:
+    from llava.train.llava_trainer import get_mm_adapter_state_maybe_zero_3
+except ImportError:
+
+    def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+        raise ImportError("missing LLaVA package")
+
+
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
@@ -243,6 +251,36 @@ class AxolotlTrainer(Trainer):
         #     return (loss, outputs) if return_outputs else loss
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
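+    # NOTE (editor): a reading of the overrides below, not author text -- when
+    # only the multimodal adapter is tuned, checkpointing bypasses the stock
+    # Trainer path and writes just the config plus the matched adapter tensors
+    # (gathered even under ZeRO-3, per the helper's name) to mm_projector.bin;
+    # the paired `_save` override is then a deliberate no-op so the frozen
+    # backbone is never duplicated on disk.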
+    def _save_checkpoint(self, model, trial, metrics=None):
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
+            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+
+            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+            run_dir = self._get_output_dir(trial=trial)
+            output_dir = os.path.join(run_dir, checkpoint_folder)
+
+            # Only save Adapter
+            keys_to_match = ["mm_projector", "vision_resampler"]
+            if getattr(self.args, "use_im_start_end", False):
+                keys_to_match.extend(["embed_tokens", "embed_in"])
+
+            weight_to_save = get_mm_adapter_state_maybe_zero_3(
+                self.model.named_parameters(), keys_to_match
+            )
+
+            if self.args.local_rank in (0, -1):
+                self.model.config.save_pretrained(output_dir)
+                torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
+        else:
+            super()._save_checkpoint(model, trial, metrics)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        if getattr(self.args, "tune_mm_mlp_adapter", False):
+            pass
+        else:
+            super()._save(output_dir, state_dict)
+
 
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
@@ -628,18 +666,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             sys.path.append(self.cfg.torchdistx_path)
             importlib.import_module("torchdistx")
 
-        data_collator_kwargs = {
-            "padding": True,  # True/"longest" is the default
-        }
-        if self.cfg.pad_to_sequence_len:
-            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
-                self.cfg.sequence_len / 64
-            )
-        else:
-            # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
-            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
-            data_collator_kwargs["pad_to_multiple_of"] = 64
-
         if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
             from axolotl.monkeypatch.llama_landmark_attn import (
                 add_mem_tokens,
@@ -664,22 +690,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
         )
+        trainer_collator_kwargs = self.build_data_collator()
+
         trainer = trainer_cls(
             model=self.model,
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            data_collator=DataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
-            bench_data_collator=transformers.DataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **data_collator_kwargs,
-            ),
             callbacks=self.get_callbacks(),
+            **trainer_collator_kwargs,
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
@@ -687,3 +706,41 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             trainer.add_callback(callback)
 
         return trainer
+
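+    # NOTE (editor): collator construction is factored out below so the
+    # multimodal path can return LLaVA's DataCollatorForSupervisedDataset
+    # (which also batches image tensors) while text-only runs keep the padded
+    # seq2seq collator; bench_data_collator is now only built when
+    # do_bench_eval is set, rather than unconditionally as before.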
+    def build_data_collator(self):
+        data_collator_kwargs = {
+            "padding": True,  # True/"longest" is the default
+        }
+        if self.cfg.pad_to_sequence_len:
+            data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
+                self.cfg.sequence_len / 64
+            )
+        else:
+            # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
+            # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+            data_collator_kwargs["pad_to_multiple_of"] = 64
+
+        collator_kwargs = {}
+        if self.cfg.multimodal:
+            from llava.train.train import DataCollatorForSupervisedDataset
+
+            collator_kwargs["data_collator"] = DataCollatorForSupervisedDataset(
+                tokenizer=self.tokenizer,
+            )
+        else:
+            collator_kwargs["data_collator"] = DataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **data_collator_kwargs,
+            )
+
+        if self.cfg.do_bench_eval:
+            collator_kwargs[
+                "bench_data_collator"
+            ] = transformers.DataCollatorForSeq2Seq(
+                self.tokenizer,
+                return_tensors="pt",
+                **data_collator_kwargs,
+            )
+
+        return collator_kwargs
diff --git a/src/axolotl/models/llava/__init__.py b/src/axolotl/models/llava/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/models/llava/llava_mistral.py b/src/axolotl/models/llava/llava_mistral.py
new file mode 100644
index 000000000..f6c85e457
--- /dev/null
+++ b/src/axolotl/models/llava/llava_mistral.py
@@ -0,0 +1,167 @@
+"""
+LLaVA Mistral classes
+"""
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    MistralConfig,
+    MistralForCausalLM,
+    MistralModel,
+)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+class LlavaMistralConfig(MistralConfig):
+    """
+    HF Transformers Config for Mistral w LLaVA
+    """
+
+    model_type = "llava_mistral"
+
+
+class LlavaMistralModel(LlavaMetaModel, MistralModel):
+    """
+    HF Transformers Model for Mistral w LLaVA
+    """
+
+    config_class = LlavaMistralConfig
+
+    def __init__(
+        self, config: LlavaMistralConfig
+    ):  # pylint: disable=useless-parent-delegation
+        super().__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+    """
+    HF Transformers Causal Model for Mistral w LLaVA
+    """
+
+    config_class = LlavaMistralConfig
+
+    def __init__(self, config: LlavaMistralConfig):
+        super().__init__(config)
+        self.model = LlavaMistralModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        (
+            input_ids,
+            attention_mask,
+            past_key_values,
+            inputs_embeds,
+            labels,
+        ) = self.prepare_inputs_labels_for_multimodal(
+            input_ids, attention_mask, past_key_values, labels, images
+        )
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
+
+
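+# NOTE (editor): registering the pair below lets AutoConfig /
+# AutoModelForCausalLM resolve model_type "llava_mistral" from a saved
+# config.json without callers importing this module's classes directly.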
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 468d25e14..54e8972e4 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -20,6 +20,14 @@ from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
+try:
+    from llava.train.train import safe_save_model_for_hf_trainer
+except ImportError:
+
+    def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+        raise ImportError("missing LLaVA package")
+
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -137,6 +145,8 @@ def train(
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
         trainer.save_model(cfg.output_dir)
+    elif cfg.multimodal:
+        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=cfg.output_dir)
     elif cfg.deepspeed and is_deepspeed_zero3_enabled():
         # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
         trainer.accelerator.wait_for_everyone()
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 81660ae65..083af0764 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -369,6 +369,15 @@ def validate_config(cfg):
                 "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
             )
 
+    if cfg.multimodal:
+        try:
+            import llava  # noqa: F401 # pylint:disable=unused-import
+        except ImportError as exc:
+            LOG.warning(
+                "LLaVA package required for multimodal training. See docs/llava.md for more information."
+            )
+            raise exc
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index ea21ce8f9..07fc7c989 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -255,7 +255,92 @@ def load_model(
         model_kwargs["use_flash_attention_2"] = True
 
     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if cfg.multimodal:
+            from llava.train.train import DataArguments, ModelArguments
+
+            if cfg.is_llama_derived_model:
+                from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
+
+                model = LlavaLlamaForCausalLM.from_pretrained(
+                    cfg.base_model,
+                )
+            elif cfg.is_mistral_derived_model:
+                from axolotl.models.llava.llava_mistral import LlavaMistralForCausalLM
+
+                model = LlavaMistralForCausalLM.from_pretrained(
+                    cfg.base_model,
+                )
+            else:
+                raise NotImplementedError(
+                    "unhandled model architecture for multimodal training"
+                )
+
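+            # NOTE (editor): a reading of the block below -- mm_freeze_backbone
+            # freezes the LM weights, while the forward hook keeps the embedding
+            # outputs requiring grad, presumably so gradients can still reach the
+            # mm_projector (and any adapters) through the frozen backbone.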
+            if cfg.mm_freeze_backbone:
+                model.model.requires_grad_(False)
+
+            def make_inputs_require_grad(
+                module, input, output
+            ):  # pylint: disable=redefined-builtin,unused-argument
+                output.requires_grad_(True)
+
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+            model_args = ModelArguments(
+                model_name_or_path=cfg.base_model,
+                version="v0",
+                freeze_backbone=cfg.mm_freeze_backbone or False,
+                tune_mm_mlp_adapter=cfg.tune_mm_mlp_adapter or False,
+                vision_tower=cfg.mm_vision_tower,
+                mm_vision_select_layer=cfg.mm_vision_select_layer or -1,
+                pretrain_mm_mlp_adapter=cfg.pretrain_mm_mlp_adapter,
+                mm_projector_type=cfg.mm_projector_type or "linear",
+                mm_use_im_start_end=cfg.mm_use_im_start_end or False,
+                # `... or True` always evaluates to True and would clobber an
+                # explicit False, so only fall back to the default when unset
+                mm_use_im_patch_token=cfg.mm_use_im_patch_token
+                if cfg.mm_use_im_patch_token is not None
+                else True,
+                mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
+            )
+
+            if cfg.mm_vision_tower:
+                model.get_model().initialize_vision_modules(
+                    model_args=model_args, fsdp=cfg.fsdp
+                )
+
+                vision_tower = model.get_vision_tower()
+                vision_tower.to(dtype=cfg.torch_dtype)
+
+                data_args = DataArguments(
+                    data_path=None,
+                    lazy_preprocess=cfg.mm_lazy_preprocess
+                    if cfg.mm_lazy_preprocess is not None
+                    else True,
+                    is_multimodal=True,
+                    image_folder=None,
+                    image_aspect_ratio="square",
+                    image_grid_pinpoints=None,
+                )
+                data_args.image_processor = vision_tower.image_processor
+                model.config.image_aspect_ratio = data_args.image_aspect_ratio
+                model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+                model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+                if model_args.tune_mm_mlp_adapter:
+                    model.requires_grad_(False)
+                    for (
+                        p  # pylint: disable=invalid-name
+                    ) in model.get_model().mm_projector.parameters():
+                        p.requires_grad = True
+
+                model.config.freeze_mm_mlp_adapter = cfg.freeze_mm_mlp_adapter
+                if cfg.freeze_mm_mlp_adapter:
+                    for (
+                        p  # pylint: disable=invalid-name
+                    ) in model.get_model().mm_projector.parameters():
+                        p.requires_grad = False
+
+                model.config.mm_use_im_start_end = (
+                    data_args.mm_use_im_start_end
+                ) = model_args.mm_use_im_start_end
+                model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+                model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+        elif cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
 
             config_kwargs = {}
@@ -520,7 +605,14 @@ def load_llama_adapter(model, cfg):
 def find_all_linear_names(model):
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
+    multimodal_keywords = [
+        "mm_projector",
+        "vision_tower",
+        "vision_resampler",
+    ]  # for LLaVA
     for name, module in model.named_modules():
+        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+            continue
         if (
             isinstance(module, cls)
             or "Linear" in module.__class__.__name__