Multimodal Vision Llama - rudimentary support (#1940)
--------- Co-authored-by: Sunny <sunny@Sunnys-MacBook-Air.local> Co-authored-by: sunny <sunnyliu19981005@gmail.com>
This commit is contained in:
@@ -51,20 +51,31 @@ from axolotl.utils.trainer import (
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_local_main_process()):
|
||||
if cfg.test_datasets:
|
||||
train_dataset, _, prompters = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="train",
|
||||
processor=processor,
|
||||
)
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
)
|
||||
else:
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
tokenizer,
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
processor=processor,
|
||||
)
|
||||
else:
|
||||
path = cfg.pretraining_dataset
|
||||
@@ -123,6 +134,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
@@ -180,6 +192,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.dataset_prepared_path
|
||||
and any(prepared_ds_path.glob("*"))
|
||||
and not cfg.is_preprocess
|
||||
and not cfg.skip_prepare_dataset
|
||||
):
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
@@ -423,12 +436,16 @@ def load_tokenized_prepared_datasets(
|
||||
dataset=ds,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
processor=processor,
|
||||
)
|
||||
datasets.append(dataset_wrapper)
|
||||
prompters.append(dataset_prompter)
|
||||
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
if len(datasets) == 1:
|
||||
dataset = datasets[0]
|
||||
else:
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
if len(datasets) > 1:
|
||||
if cfg.shuffle_merged_datasets:
|
||||
@@ -437,9 +454,10 @@ def load_tokenized_prepared_datasets(
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged datasets")
|
||||
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
if not cfg.skip_prepare_dataset:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
if cfg.push_dataset_to_hub:
|
||||
@@ -478,9 +496,14 @@ def load_prepare_datasets(
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path, split=split
|
||||
tokenizer,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
split=split,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
@@ -546,6 +569,7 @@ def get_dataset_wrapper(
|
||||
d_base_type,
|
||||
dataset,
|
||||
d_prompt_style=None,
|
||||
processor=None,
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
@@ -578,7 +602,11 @@ def get_dataset_wrapper(
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
)
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
elif cfg.skip_prepare_dataset:
|
||||
dataset_wrapper = dataset
|
||||
elif ds_strategy := load(
|
||||
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
|
||||
Reference in New Issue
Block a user