More fixes to try to get multimodal (mm) support working
This commit is contained in:
@@ -2,9 +2,10 @@ base_model: mistralai/Mistral-7B-v0.1
|
|||||||
model_type: MistralForCausalLM
|
model_type: MistralForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
is_mistral_derived_model: true
|
is_mistral_derived_model: true
|
||||||
multimodal: true
|
|
||||||
|
|
||||||
vision_tower: openai/clip-vit-large-patch14
|
# multimodal pretrain
|
||||||
|
multimodal: true
|
||||||
|
mm_vision_tower: openai/clip-vit-large-patch14
|
||||||
tune_mm_mlp_adapter: true
|
tune_mm_mlp_adapter: true
|
||||||
mm_vision_select_layer: -2
|
mm_vision_select_layer: -2
|
||||||
mm_projector_type: mlp2x_gelu
|
mm_projector_type: mlp2x_gelu
|
||||||
@@ -21,7 +22,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: false
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
@@ -237,10 +237,11 @@ def load_mm_dataset(
|
|||||||
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
|
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
|
||||||
)
|
)
|
||||||
data_args.image_processor = vision_tower.image_processor
|
data_args.image_processor = vision_tower.image_processor
|
||||||
|
data_args.mm_use_im_start_end = cfg.mm_use_im_start_end or False
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
train_dataset = LazySupervisedDataset(
|
train_dataset = LazySupervisedDataset(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_path=data_args["data_path"],
|
data_path=data_args.data_path,
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from colorama import Fore
|
from colorama import Fore
|
||||||
|
|
||||||
@@ -47,6 +48,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
dataset_meta = load_mm_dataset(
|
dataset_meta = load_mm_dataset(
|
||||||
cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
|
cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
|
||||||
)
|
)
|
||||||
|
del model
|
||||||
|
torch.cuda.empty_cache()
|
||||||
if parsed_cli_args.prepare_ds_only:
|
if parsed_cli_args.prepare_ds_only:
|
||||||
return
|
return
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|||||||
Reference in New Issue
Block a user