More fixes to try to get multimodal (mm) support working

This commit is contained in:
Wing Lian
2023-10-23 23:15:33 -04:00
parent b885169229
commit fdc3e4d505
3 changed files with 9 additions and 4 deletions

View File

@@ -2,9 +2,10 @@ base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
multimodal: true
vision_tower: openai/clip-vit-large-patch14
# multimodal pretrain
multimodal: true
mm_vision_tower: openai/clip-vit-large-patch14
tune_mm_mlp_adapter: true
mm_vision_select_layer: -2
mm_projector_type: mlp2x_gelu
@@ -21,7 +22,7 @@ val_set_size: 0.01
output_dir: ./out
sequence_len: 2048
sample_packing: true
sample_packing: false
pad_to_sequence_len: true
wandb_project:

View File

@@ -237,10 +237,11 @@ def load_mm_dataset(
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
)
data_args.image_processor = vision_tower.image_processor
data_args.mm_use_im_start_end = cfg.mm_use_im_start_end or False
tokenizer = load_tokenizer(cfg)
train_dataset = LazySupervisedDataset(
tokenizer=tokenizer,
data_path=data_args["data_path"],
data_path=data_args.data_path,
data_args=data_args,
)

View File

@@ -5,6 +5,7 @@ import logging
from pathlib import Path
import fire
import torch
import transformers
from colorama import Fore
@@ -47,6 +48,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
dataset_meta = load_mm_dataset(
cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
)
del model
torch.cuda.empty_cache()
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)