Create preprocess CLI (#785)

* Create preprocess CLI

* Print prompt template if debugging

* Add print for unsupported prompters

* Formatting

* Formatting

* Refactor variables

* Formatting

* Formatting

* Formatting

* Formatting
This commit is contained in:
Casper
2023-10-26 15:35:42 +02:00
committed by GitHub
parent 05bd6f1122
commit e50ab072e2
9 changed files with 354 additions and 190 deletions

View File

@@ -32,7 +32,6 @@ Features:
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
- [Training w/ Deepspeed](#training-with-deepspeed)
- [Inference](#inference) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
@@ -824,14 +823,41 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml accelerate launch -m axolotl.cli.train your_config.yml
``` ```
#### Multi-GPU #### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- Use `--debug` to see preprocessed examples.
You can optionally pre-tokenize dataset with the following before finetuning:
```bash ```bash
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only python -m axolotl.cli.preprocess your_config.yml
``` ```
##### Config #### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
##### FSDP
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -856,24 +882,6 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
``` ```
### Training with Deepspeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
or
```yaml
deepspeed: deepspeed/zero1.json
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:

View File

@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
shard(cfg=parsed_cfg, cli_args=parsed_cli_args) shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -222,7 +222,9 @@ def load_datasets(
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
@@ -238,6 +240,10 @@ def load_datasets(
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,
) )
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,

View File

@@ -0,0 +1,53 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import fire import fire
import transformers import transformers
from colorama import Fore
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -16,7 +15,6 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -25,11 +25,22 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5) debug_num_examples: int = field(default=5)
inference: bool = field(default=False) inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -245,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
raise NotImplementedError raise NotImplementedError
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
( (
instruction, instruction,
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin

View File

@@ -4,10 +4,12 @@ import logging
from enum import Enum from enum import Enum
from typing import Generator, Optional, Union from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
class PromptStyle(Enum): class PromptStyle(Enum):
@@ -55,20 +57,15 @@ class AlpacaPrompter:
) )
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def build_prompt( def _build_result(self, instruction, input_text, output):
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input_text:
res = ( res = (
self.system_format.format(system=self.system_prompt) self.system_format.format(system=self.system_prompt)
if self.system_prompt if self.system_prompt
else "" else ""
) + self.turn_format.format(instruction=instruction, input=input) ) + self.turn_format.format(instruction=instruction, input=input_text)
else: else:
res = ( res = (
self.system_format.format(system=self.system_no_input_prompt) self.system_format.format(system=self.system_no_input_prompt)
@@ -77,7 +74,21 @@ class AlpacaPrompter:
) + self.turn_no_input_format.format(instruction=instruction) ) + self.turn_no_input_format.format(instruction=instruction)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
) )
self.response_split = "ASSISTANT:" self.response_split = "ASSISTANT:"
def build_prompt( def _build_result(
self, self,
instruction: str, instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None, output: Union[None, str] = None,
reflection: Union[None, str] = None, reflection: Union[None, str] = None,
corrected: Union[None, str] = None, corrected: Union[None, str] = None,
) -> Generator[str, None, None]: ):
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
corrected=corrected, corrected=corrected,
) )
res = f"{res}{label}" res = f"{res}{label}"
yield res
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
# pylint: disable=duplicate-code
yield self._build_result(
instruction,
input,
output,
reflection,
corrected,
)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
SHAREGPT_ASSERTION_FAILED_ROLE = ( SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -247,7 +281,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if role_key_model: if role_key_model:
self.role_key_model = role_key_model self.role_key_model = role_key_model
def build_prompt(self, source) -> Generator[str, None, None]: def _build_result(self, source):
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
@@ -282,11 +316,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_turns(): return conv.get_turns()
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]: if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}") LOG.warning(f"role with empty message: {part[0]}")
yield part yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter): class ShareGPTPrompterV2(ShareGPTPrompter):
""" """
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
role_key_human=role_key_human, role_key_human=role_key_human,
role_key_model=role_key_model, role_key_model=role_key_model,
) )
class UnsupportedPrompter:
"""
A dummy class for custom prompters
"""
def __init__(self) -> None:
pass
def __repr__(self):
return "Pre-tokenized or custom dataset types are unsupported for logging"

View File

@@ -3,7 +3,7 @@ import functools
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -36,6 +36,7 @@ from axolotl.prompters import (
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
@@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
@@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps return train_dataset, eval_dataset, cfg.max_steps, prompters
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
@@ -83,7 +85,7 @@ def prepare_dataset(cfg, tokenizer):
LOG.info(f"Maximum number of steps set at {total_num_steps}") LOG.info(f"Maximum number of steps set at {total_num_steps}")
else: else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
return train_dataset, eval_dataset, total_num_steps return train_dataset, eval_dataset, total_num_steps, prompters
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
else Path(default_dataset_prepared_path) / ds_hash else Path(default_dataset_prepared_path) / ds_hash
) )
dataset = None dataset = None
prompters = []
use_auth_token = cfg.hf_use_auth_token use_auth_token = cfg.hf_use_auth_token
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
@@ -147,13 +150,13 @@ def load_tokenized_prepared_datasets(
yield dataset yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for d in for_d_in_datasets(cfg.datasets): for config_dataset in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset( load_dataset(
d.path, config_dataset.path,
name=d.name, name=config_dataset.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
) )
@@ -162,33 +165,33 @@ def load_tokenized_prepared_datasets(
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(d.path) local_path = Path(config_dataset.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk` # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset( ds = load_dataset(
d.path, config_dataset.path,
name=d.name, name=config_dataset.name,
data_files=d.data_files, data_files=config_dataset.data_files,
streaming=False, streaming=False,
split=None, split=None,
) )
elif local_path.is_file(): elif local_path.is_file():
ds_type = "json" ds_type = "json"
if d.ds_type: if config_dataset.ds_type:
ds_type = d.ds_type ds_type = config_dataset.ds_type
elif ".parquet" in d.path: elif ".parquet" in config_dataset.path:
ds_type = "parquet" ds_type = "parquet"
elif ".arrow" in d.path: elif ".arrow" in config_dataset.path:
ds_type = "arrow" ds_type = "arrow"
elif ".csv" in d.path: elif ".csv" in config_dataset.path:
ds_type = "csv" ds_type = "csv"
elif ".txt" in d.path: elif ".txt" in config_dataset.path:
ds_type = "text" ds_type = "text"
ds = load_dataset( ds = load_dataset(
ds_type, ds_type,
name=d.name, name=config_dataset.name,
data_files=d.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None, split=None,
) )
@@ -198,25 +201,25 @@ def load_tokenized_prepared_datasets(
) )
elif ds_from_hub: elif ds_from_hub:
ds = load_dataset( ds = load_dataset(
d.path, config_dataset.path,
name=d.name, name=config_dataset.name,
streaming=False, streaming=False,
data_files=d.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
) )
else: else:
if isinstance(d.data_files, str): if isinstance(config_dataset.data_files, str):
fp = hf_hub_download( fp = hf_hub_download(
repo_id=d.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=d.data_files, filename=config_dataset.data_files,
) )
elif isinstance(d.data_files, list): elif isinstance(config_dataset.data_files, list):
fp = [] fp = []
for file in d.data_files: for file in config_dataset.data_files:
fp.append( fp.append(
hf_hub_download( hf_hub_download(
repo_id=d.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
) )
@@ -226,21 +229,27 @@ def load_tokenized_prepared_datasets(
"data_files must be either a string or list of strings" "data_files must be either a string or list of strings"
) )
ds = load_dataset( ds = load_dataset(
"json", name=d.name, data_files=fp, streaming=False, split=None "json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if config_dataset.shards:
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0 num_shards=config_dataset.shards, index=0
) )
else: else:
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0) ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=0
)
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = d.type d_type = config_dataset.type
if isinstance(d_type, str): if isinstance(d_type, str):
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -249,108 +258,26 @@ def load_tokenized_prepared_datasets(
ds = ds["train"] ds = ds["train"]
elif ( elif (
isinstance(ds, DatasetDict) isinstance(ds, DatasetDict)
and d.train_on_split and config_dataset.train_on_split
and d.train_on_split in ds and config_dataset.train_on_split in ds
): ):
ds = ds[d.train_on_split] ds = ds[config_dataset.train_on_split]
elif isinstance(ds, DatasetDict): elif isinstance(ds, DatasetDict):
raise ValueError( raise ValueError(
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `" f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
)
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
and "labels" in ds.features
):
# dataset is already tokenized, just drop it straight in
datasets.append(ds)
elif isinstance(d.type, DictDefault):
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(
GPTeacherPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
ReflectAlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
LOG.info("merging datasets") LOG.info("merging datasets")
dataset = concatenate_datasets(datasets) dataset = concatenate_datasets(datasets)
@@ -368,14 +295,14 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
return dataset return dataset, prompters
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset]: ) -> Tuple[Dataset, Dataset, List[Any]]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -384,6 +311,7 @@ def load_prepare_datasets(
) # make sure we don't accidentally set it larger than sequence_len ) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
prompters = []
if cfg.max_packed_sequence_len is not None: if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -439,7 +367,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
else: else:
dataset = load_tokenized_prepared_datasets( dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -481,7 +409,7 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
dataset = load_tokenized_prepared_datasets( dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -532,7 +460,124 @@ def load_prepare_datasets(
train_dataset = dataset train_dataset = dataset
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset return train_dataset, eval_dataset, prompters
def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
if (
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
# dataset is already tokenized, just drop it straight in
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = dataset
elif isinstance(config_dataset.type, DictDefault):
ds_strategy = load(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
ds_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
ds_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
else:
suffix = ""
if ":load_" in config_dataset.type:
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
LOG.error(
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
)
return dataset_wrapper, dataset_prompter
def encode_pretraining( def encode_pretraining(