diff --git a/examples/multimodal/pretrain-llava-llama.yml b/examples/multimodal/pretrain-llava-llama.yml
index f03ae28d2..855629ee4 100644
--- a/examples/multimodal/pretrain-llava-llama.yml
+++ b/examples/multimodal/pretrain-llava-llama.yml
@@ -1,7 +1,7 @@
-base_model: mistralai/Mistral-7B-v0.1
-model_type: MistralForCausalLM
+base_model: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
-is_mistral_derived_model: true
+is_llama_derived_model: true
 
 # multimodal pretrain
 multimodal: true
@@ -16,9 +16,10 @@ load_in_4bit: false
 strict: false
 
 datasets:
-  - path: liuhaotian/LLaVA-CC3M-Pretrain-595K
+  - path: ./data/blip_laion_cc_sbu_558k.json
+# - path: liuhaotian/LLaVA-CC3M-Pretrain-595K
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.0
 output_dir: ./out
 
 sequence_len: 2048
@@ -33,8 +34,8 @@ wandb_log_model:
 
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
-optimizer: adamw_bnb_8bit
+num_epochs: 1
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.002
 
diff --git a/examples/multimodal/pretrain-llava-mistral.yml b/examples/multimodal/pretrain-llava-mistral.yml
index f03ae28d2..2d85a480c 100644
--- a/examples/multimodal/pretrain-llava-mistral.yml
+++ b/examples/multimodal/pretrain-llava-mistral.yml
@@ -18,7 +18,7 @@ strict: false
 datasets:
   - path: liuhaotian/LLaVA-CC3M-Pretrain-595K
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.0
 output_dir: ./out
 
 sequence_len: 2048
@@ -33,8 +33,8 @@ wandb_log_model:
 
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 4
-optimizer: adamw_bnb_8bit
+num_epochs: 1
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.002
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a25d00716..f3a68dc41 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -278,12 +278,21 @@ def load_model(
         if cfg.mm_freeze_backbone:
             model.model.requires_grad_(False)
 
-        def make_inputs_require_grad(
-            module, input, output
-        ):  # pylint: disable=redefined-builtin,unused-argument
-            output.requires_grad_(True)
+        if cfg.gradient_checkpointing:
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
 
-        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+                def make_inputs_require_grad(
+                    module,
+                    input,
+                    output,  # pylint: disable=redefined-builtin,unused-argument
+                ):
+                    output.requires_grad_(True)
+
+                model.get_input_embeddings().register_forward_hook(
+                    make_inputs_require_grad
+                )
 
         model_args = ModelArguments(
             model_name_or_path=cfg.base_model,
@@ -299,13 +308,13 @@ def load_model(
             mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
         )
 
-        if cfg.mm_vision_tower:
+        if cfg.mm_vision_tower is not None:
             model.get_model().initialize_vision_modules(
                 model_args=model_args, fsdp=cfg.fsdp
             )
 
             vision_tower = model.get_vision_tower()
-            vision_tower.to(dtype=cfg.torch_dtype)
+            vision_tower.to(dtype=cfg.torch_dtype, device=cfg.device)
 
         # pylint: disable=duplicate-code
         data_args = DataArguments(
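
For context on the `models.py` hunk above: recent `transformers` releases expose `PreTrainedModel.enable_input_require_grads()`, which makes the embedding output require grad so gradient checkpointing can still backpropagate when the backbone is frozen; the diff falls back to a manual forward hook when that method is missing. Below is a minimal standalone sketch of that same pattern, not the axolotl code itself; the helper name and the model id are illustrative assumptions.

```python
from transformers import AutoModelForCausalLM


def enable_grad_flow_for_checkpointing(model):
    """Prefer the built-in transformers helper; otherwise register a forward
    hook on the input embeddings so their output requires grad. Gradient
    checkpointing needs this when the embedding weights are frozen."""
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:

        def make_inputs_require_grad(module, input, output):  # pylint: disable=redefined-builtin,unused-argument
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)


if __name__ == "__main__":
    # Illustrative model id; any causal LM that supports gradient checkpointing works.
    model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")
    model.gradient_checkpointing_enable()
    enable_grad_flow_for_checkpointing(model)
```

Either branch ensures activations entering the checkpointed layers participate in autograd, so the recomputation in the backward pass produces gradients even with a frozen backbone.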