From ab9d12ce345377e1adda2a14184d3e2caf001330 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 23 Oct 2023 21:44:07 -0400 Subject: [PATCH] handle dataset loading for multimodal --- examples/multimodal/llava-mistral.yml | 62 +++++++++++++++++++++++++++ src/axolotl/cli/__init__.py | 40 +++++++++++++++++ src/axolotl/cli/train_mm.py | 56 ++++++++++++++++++++++++ src/axolotl/utils/data.py | 15 ++++++- src/axolotl/utils/models.py | 9 ++-- 5 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 examples/multimodal/llava-mistral.yml create mode 100644 src/axolotl/cli/train_mm.py diff --git a/examples/multimodal/llava-mistral.yml b/examples/multimodal/llava-mistral.yml new file mode 100644 index 000000000..d38088347 --- /dev/null +++ b/examples/multimodal/llava-mistral.yml @@ -0,0 +1,62 @@ +base_model: mistralai/Mistral-7B-v0.1 +model_type: MistralForCausalLM +tokenizer_type: LlamaTokenizer +is_mistral_derived_model: true + +vision_tower: openai/clip-vit-large-patch14 +tune_mm_mlp_adapter: true +mm_vision_select_layer: -2 +mm_projector_type: mlp2x_gelu +mm_image_folder: ./llava/ + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: liuhaotian/LLaVA-CC3M-Pretrain-595K +dataset_prepared_path: +val_set_size: 0.01 +output_dir: ./out + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +warmup_steps: 10 +eval_steps: 0.05 +save_steps: +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: "" diff --git 
a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 07a6209e4..680bc9746 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -2,6 +2,7 @@ import importlib import logging +import math import os import random import sys @@ -215,6 +216,45 @@ def load_cfg(config: Path = Path("examples/"), **kwargs): return cfg +def load_mm_dataset( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: disable=unused-argument + model, +): + # pylint: disable=duplicate-code + from llava.train.train import DataArguments, LazySupervisedDataset + + vision_tower = model.get_vision_tower() + data_args = DataArguments( + data_path=cfg.datasets[0]["path"], + lazy_preprocess=cfg.mm_lazy_preprocess + if cfg.mm_lazy_preprocess is not None + else True, + is_multimodal=True, + image_folder=cfg.mm_image_folder or None, + image_aspect_ratio=cfg.mm_image_aspect_ratio or "square", + image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None, + ) + data_args.image_processor = vision_tower.image_processor + tokenizer = load_tokenizer(cfg) + train_dataset = LazySupervisedDataset( + tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args, + ) + + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=None, + total_num_steps=total_num_steps, + ) + + def load_datasets( + *, + cfg: DictDefault, diff --git a/src/axolotl/cli/train_mm.py b/src/axolotl/cli/train_mm.py new file mode 100644 index 000000000..9039d70c6 --- /dev/null +++ b/src/axolotl/cli/train_mm.py @@ -0,0 +1,56 @@ +""" +CLI to run training on a model +""" +import logging +from pathlib import Path + +import fire +import transformers +from colorama import Fore + +from axolotl.cli import ( + check_accelerate_default_config, + check_user_token, + load_cfg, + load_mm_dataset, + print_axolotl_text_art, +) +from axolotl.common.cli import TrainerCliArgs +from axolotl.common.const
import DEFAULT_DATASET_PREPARED_PATH +from axolotl.train import train +from axolotl.utils.models import load_model, load_tokenizer + +LOG = logging.getLogger("axolotl.cli.train") + + +def do_cli(config: Path = Path("examples/"), **kwargs): + # pylint: disable=duplicate-code + print_axolotl_text_art() + parsed_cfg = load_cfg(config, **kwargs) + check_accelerate_default_config() + check_user_token() + parser = transformers.HfArgumentParser((TrainerCliArgs)) + parsed_cli_args, _ = parser.parse_args_into_dataclasses( + return_remaining_strings=True + ) + if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path: + msg = ( + Fore.RED + + "--prepare_ds_only called without dataset_prepared_path set." + + Fore.RESET + ) + LOG.warning(msg) + parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH + + tokenizer = load_tokenizer(parsed_cfg) + model = load_model(parsed_cfg, tokenizer) + dataset_meta = load_mm_dataset( + cfg=parsed_cfg, cli_args=parsed_cli_args, model=model + ) + if parsed_cli_args.prepare_ds_only: + return + train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) + + +if __name__ == "__main__": + fire.Fire(do_cli) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 99697de32..a23164b7c 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -54,8 +54,19 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str: return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec -def prepare_dataset(cfg, tokenizer): - if not cfg.pretraining_dataset: +def prepare_dataset(cfg, tokenizer, model=None): + if cfg.multimodal: + if not model: + raise ValueError("missing model argument") + from llava.train.train import LazySupervisedDataset + + with zero_first(is_main_process()): + eval_dataset = None + train_dataset = LazySupervisedDataset( + tokenizer=tokenizer, + ) + + elif not cfg.pretraining_dataset: with zero_first(is_main_process()): train_dataset, eval_dataset = 
load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 07fc7c989..a25d00716 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -307,15 +307,16 @@ def load_model( vision_tower = model.get_vision_tower() vision_tower.to(dtype=cfg.torch_dtype) + # pylint: disable=duplicate-code data_args = DataArguments( - data_path=None, + data_path=cfg.datasets[0]["path"], lazy_preprocess=cfg.mm_lazy_preprocess if cfg.mm_lazy_preprocess is not None else True, is_multimodal=True, - image_folder=None, - image_aspect_ratio="square", - image_grid_pinpoints=None, + image_folder=cfg.mm_image_folder or None, + image_aspect_ratio=cfg.mm_image_aspect_ratio or "square", + image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None, ) data_args.image_processor = vision_tower.image_processor model.config.image_aspect_ratio = data_args.image_aspect_ratio