From fdc3e4d505820f8d45d1696e1c4a7e37840a7ef7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Oct 2023 23:15:33 -0400 Subject: [PATCH] more fixes to try to get mm working --- .../{llava-mistral.yml => pretrain-llava-mistral.yml} | 7 ++++--- src/axolotl/cli/__init__.py | 3 ++- src/axolotl/cli/train_mm.py | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) rename examples/multimodal/{llava-mistral.yml => pretrain-llava-mistral.yml} (92%) diff --git a/examples/multimodal/llava-mistral.yml b/examples/multimodal/pretrain-llava-mistral.yml similarity index 92% rename from examples/multimodal/llava-mistral.yml rename to examples/multimodal/pretrain-llava-mistral.yml index 9f371191f..f03ae28d2 100644 --- a/examples/multimodal/llava-mistral.yml +++ b/examples/multimodal/pretrain-llava-mistral.yml @@ -2,9 +2,10 @@ base_model: mistralai/Mistral-7B-v0.1 model_type: MistralForCausalLM tokenizer_type: LlamaTokenizer is_mistral_derived_model: true -multimodal: true -vision_tower: openai/clip-vit-large-patch14 +# multimodal pretrain +multimodal: true +mm_vision_tower: openai/clip-vit-large-patch14 tune_mm_mlp_adapter: true mm_vision_select_layer: -2 mm_projector_type: mlp2x_gelu @@ -21,7 +22,7 @@ val_set_size: 0.01 output_dir: ./out sequence_len: 2048 -sample_packing: true +sample_packing: false pad_to_sequence_len: true wandb_project: diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 680bc9746..8c272270e 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -237,10 +237,11 @@ def load_mm_dataset( image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None, ) data_args.image_processor = vision_tower.image_processor + data_args.mm_use_im_start_end = cfg.mm_use_im_start_end or False tokenizer = load_tokenizer(cfg) train_dataset = LazySupervisedDataset( tokenizer=tokenizer, - data_path=data_args["data_path"], + data_path=data_args.data_path, data_args=data_args, ) diff --git a/src/axolotl/cli/train_mm.py b/src/axolotl/cli/train_mm.py index ea3be1092..50582e88d 100644 --- a/src/axolotl/cli/train_mm.py +++ b/src/axolotl/cli/train_mm.py @@ -5,6 +5,7 @@ import logging from pathlib import Path import fire +import torch import transformers from colorama import Fore @@ -47,6 +48,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs): dataset_meta = load_mm_dataset( cfg=parsed_cfg, cli_args=parsed_cli_args, model=model ) + del model + torch.cuda.empty_cache() if parsed_cli_args.prepare_ds_only: return train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)