From b52e61a574419800245980b9775839d3edf20dc4 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 30 Oct 2023 11:03:55 -0400
Subject: [PATCH] pretrain fixes for mm

---
 examples/multimodal/pretrain-llava-llama.yml   |  4 +++-
 examples/multimodal/pretrain-llava-mistral.yml |  2 ++
 src/axolotl/utils/models.py                    | 14 +++++++-------
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/multimodal/pretrain-llava-llama.yml b/examples/multimodal/pretrain-llava-llama.yml
index 533cec41f..e45612f4f 100644
--- a/examples/multimodal/pretrain-llava-llama.yml
+++ b/examples/multimodal/pretrain-llava-llama.yml
@@ -7,9 +7,11 @@ is_llama_derived_model: true
 multimodal: true
 mm_vision_tower: openai/clip-vit-large-patch14
 tune_mm_mlp_adapter: true
+mm_freeze_backbone: true
 mm_vision_select_layer: -2
 mm_projector_type: mlp2x_gelu
 mm_image_folder: ./llava/
+mm_use_im_patch_token: false
 
 load_in_8bit: false
 load_in_4bit: false
@@ -54,7 +56,7 @@ flash_attention: true
 
 warmup_steps: 10
 eval_steps:
-save_steps:
+save_steps: 0.1
 debug:
 deepspeed:
 weight_decay: 0.0
diff --git a/examples/multimodal/pretrain-llava-mistral.yml b/examples/multimodal/pretrain-llava-mistral.yml
index ca113c168..d935c3821 100644
--- a/examples/multimodal/pretrain-llava-mistral.yml
+++ b/examples/multimodal/pretrain-llava-mistral.yml
@@ -7,9 +7,11 @@ is_mistral_derived_model: true
 multimodal: true
 mm_vision_tower: openai/clip-vit-large-patch14
 tune_mm_mlp_adapter: true
+mm_freeze_backbone: true
 mm_vision_select_layer: -2
 mm_projector_type: mlp2x_gelu
 mm_image_folder: ./llava/
+mm_use_im_patch_token: false
 
 load_in_8bit: false
 load_in_4bit: false
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f3a68dc41..e5b57e1c4 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -286,8 +286,8 @@ def load_model(
             def make_inputs_require_grad(
                 module,
                 input,
-                output,  # pylint: disable=redefined-builtin,unused-argument
-            ):
+                output,
+            ):  # pylint: disable=redefined-builtin,unused-argument
                 output.requires_grad_(True)
 
             model.get_input_embeddings().register_forward_hook(
@@ -304,7 +304,7 @@ def load_model(
             pretrain_mm_mlp_adapter=cfg.pretrain_mm_mlp_adapter,
             mm_projector_type=cfg.mm_projector_type or "linear",
             mm_use_im_start_end=cfg.mm_use_im_start_end or False,
-            mm_use_im_patch_token=cfg.mm_use_im_patch_token or True,
+            mm_use_im_patch_token=cfg.mm_use_im_patch_token,
            mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
         )
 
@@ -330,8 +330,8 @@ def load_model(
         data_args.image_processor = vision_tower.image_processor
         model.config.image_aspect_ratio = data_args.image_aspect_ratio
         model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
-        model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
-        if model_args.tune_mm_mlp_adapter:
+        model.config.tune_mm_mlp_adapter = cfg.tune_mm_mlp_adapter
+        if cfg.tune_mm_mlp_adapter:
             model.requires_grad_(False)
             for (
                 p  # pylint: disable=invalid-name
@@ -347,8 +347,8 @@ def load_model(
 
         model.config.mm_use_im_start_end = (
             data_args.mm_use_im_start_end
-        ) = model_args.mm_use_im_start_end
-        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+        ) = cfg.mm_use_im_start_end
+        model.config.mm_use_im_patch_token = cfg.mm_use_im_patch_token
         model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
     elif cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
         from transformers import LlamaForCausalLM
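-- 
Why the mm_use_im_patch_token change above matters, as a minimal Python
sketch (the cfg_mm_use_im_patch_token variable is a stand-in for the
loaded config value, not axolotl's real config object): `x or True` is
always truthy, so an explicit `mm_use_im_patch_token: false` in the YAML
was silently overridden before this patch.

    cfg_mm_use_im_patch_token = False                # user sets mm_use_im_patch_token: false
    old_value = cfg_mm_use_im_patch_token or True    # False or True -> True: the setting is lost
    new_value = cfg_mm_use_im_patch_token            # value passed through as written
    assert old_value is True and new_value is False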