Create preprocess CLI (#785)
* Create preprocess CLI * Print prompt template if debugging * Add print for unsupported prompters * Formatting * Formatting * Refactor variables * Formatting * Formatting * Formatting * Formatting
This commit is contained in:
54
README.md
54
README.md
@@ -32,7 +32,6 @@ Features:
|
|||||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||||
- [Config](#config)
|
- [Config](#config)
|
||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
|
||||||
- [Inference](#inference)
|
- [Inference](#inference)
|
||||||
- [Merge LORA to Base](#merge-lora-to-base)
|
- [Merge LORA to Base](#merge-lora-to-base)
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
@@ -824,14 +823,41 @@ Run
|
|||||||
accelerate launch -m axolotl.cli.train your_config.yml
|
accelerate launch -m axolotl.cli.train your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Multi-GPU
|
#### Preprocess dataset
|
||||||
|
|
||||||
|
You can optionally pre-tokenize the dataset with the following command before fine-tuning.
|
||||||
|
This is recommended for large datasets.
|
||||||
|
|
||||||
|
- Set `push_dataset_to_hub: hf_user/repo` to push the preprocessed dataset to the Hugging Face Hub.
|
||||||
|
- Use `--debug` to see preprocessed examples.
|
||||||
|
|
||||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
|
||||||
```bash
|
```bash
|
||||||
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
python -m axolotl.cli.preprocess your_config.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
##### Config
|
#### Multi-GPU
|
||||||
|
|
||||||
|
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||||
|
is the recommended multi-GPU option currently because FSDP may experience
|
||||||
|
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||||
|
|
||||||
|
##### DeepSpeed
|
||||||
|
|
||||||
|
DeepSpeed is an optimization suite for multi-GPU systems that lets you train much larger models than you
|
||||||
|
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||||
|
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||||
|
|
||||||
|
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
deepspeed: deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed/zero1.json
|
||||||
|
```
|
||||||
|
|
||||||
|
##### FSDP
|
||||||
|
|
||||||
- llama FSDP
|
- llama FSDP
|
||||||
```yaml
|
```yaml
|
||||||
@@ -856,24 +882,6 @@ wandb_run_id:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
### Training with Deepspeed
|
|
||||||
|
|
||||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
|
||||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
|
||||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
|
||||||
|
|
||||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
|
||||||
|
|
||||||
```shell
|
|
||||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
|
||||||
```
|
|
||||||
|
|
||||||
or
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
deepspeed: deepspeed/zero1.json
|
|
||||||
```
|
|
||||||
|
|
||||||
### Inference
|
### Inference
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
Pass the appropriate flag to the train command:
|
||||||
|
|||||||
@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
|
||||||
return
|
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,9 @@ def load_datasets(
|
|||||||
) -> TrainDatasetMeta:
|
) -> TrainDatasetMeta:
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||||
|
cfg, tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
if cli_args.debug or cfg.debug:
|
if cli_args.debug or cfg.debug:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
@@ -238,6 +240,10 @@ def load_datasets(
|
|||||||
text_only=cli_args.debug_text_only,
|
text_only=cli_args.debug_text_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG.info("printing prompters...")
|
||||||
|
for prompter in prompters:
|
||||||
|
LOG.info(prompter)
|
||||||
|
|
||||||
return TrainDatasetMeta(
|
return TrainDatasetMeta(
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
|
|||||||
53
src/axolotl/cli/preprocess.py
Normal file
53
src/axolotl/cli/preprocess.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
CLI to preprocess (tokenize and cache) datasets ahead of training
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import transformers
|
||||||
|
from colorama import Fore
|
||||||
|
|
||||||
|
from axolotl.cli import (
|
||||||
|
check_accelerate_default_config,
|
||||||
|
check_user_token,
|
||||||
|
load_cfg,
|
||||||
|
load_datasets,
|
||||||
|
print_axolotl_text_art,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
check_accelerate_default_config()
|
||||||
|
check_user_token()
|
||||||
|
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
|
msg = (
|
||||||
|
Fore.RED
|
||||||
|
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||||
|
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
LOG.warning(msg)
|
||||||
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
|
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
LOG.info(
|
||||||
|
Fore.GREEN
|
||||||
|
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
fire.Fire(do_cli)
|
||||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -16,7 +15,6 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
|
||||||
msg = (
|
|
||||||
Fore.RED
|
|
||||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
|
||||||
+ Fore.RESET
|
|
||||||
)
|
|
||||||
LOG.warning(msg)
|
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
|
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
|
||||||
return
|
|
||||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,22 @@ class TrainerCliArgs:
|
|||||||
debug_num_examples: int = field(default=5)
|
debug_num_examples: int = field(default=5)
|
||||||
inference: bool = field(default=False)
|
inference: bool = field(default=False)
|
||||||
merge_lora: bool = field(default=False)
|
merge_lora: bool = field(default=False)
|
||||||
prepare_ds_only: bool = field(default=False)
|
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
shard: bool = field(default=False)
|
shard: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessCliArgs:
|
||||||
|
"""
|
||||||
|
dataclass representing arguments for preprocessing only
|
||||||
|
"""
|
||||||
|
|
||||||
|
debug: bool = field(default=False)
|
||||||
|
debug_text_only: bool = field(default=False)
|
||||||
|
debug_num_examples: int = field(default=1)
|
||||||
|
prompter: Optional[str] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
@@ -245,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
(
|
(
|
||||||
instruction,
|
instruction,
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Generator, Optional, Union
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
|
from colorama import Fore
|
||||||
from fastchat.conversation import Conversation, get_conv_template
|
from fastchat.conversation import Conversation, get_conv_template
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||||
|
|
||||||
|
|
||||||
class PromptStyle(Enum):
|
class PromptStyle(Enum):
|
||||||
@@ -55,20 +57,15 @@ class AlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||||
|
|
||||||
def build_prompt(
|
def _build_result(self, instruction, input_text, output):
|
||||||
self,
|
|
||||||
instruction: str,
|
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
|
||||||
output: Union[None, str] = None,
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input_text:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_prompt)
|
self.system_format.format(system=self.system_prompt)
|
||||||
if self.system_prompt
|
if self.system_prompt
|
||||||
else ""
|
else ""
|
||||||
) + self.turn_format.format(instruction=instruction, input=input)
|
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||||
else:
|
else:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_no_input_prompt)
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
@@ -77,7 +74,21 @@ class AlpacaPrompter:
|
|||||||
) + self.turn_no_input_format.format(instruction=instruction)
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
if output:
|
if output:
|
||||||
res = f"{res}{output}"
|
res = f"{res}{output}"
|
||||||
yield res
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
|
output: Union[None, str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
yield self._build_result(instruction, input, output)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return REPR_TEMPLATE.format(
|
||||||
|
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnpromptedPrompter(AlpacaPrompter):
|
class UnpromptedPrompter(AlpacaPrompter):
|
||||||
@@ -191,14 +202,14 @@ class ReflectAlpacaPrompter:
|
|||||||
)
|
)
|
||||||
self.response_split = "ASSISTANT:"
|
self.response_split = "ASSISTANT:"
|
||||||
|
|
||||||
def build_prompt(
|
def _build_result(
|
||||||
self,
|
self,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
output: Union[None, str] = None,
|
output: Union[None, str] = None,
|
||||||
reflection: Union[None, str] = None,
|
reflection: Union[None, str] = None,
|
||||||
corrected: Union[None, str] = None,
|
corrected: Union[None, str] = None,
|
||||||
) -> Generator[str, None, None]:
|
):
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input:
|
||||||
@@ -212,7 +223,30 @@ class ReflectAlpacaPrompter:
|
|||||||
corrected=corrected,
|
corrected=corrected,
|
||||||
)
|
)
|
||||||
res = f"{res}{label}"
|
res = f"{res}{label}"
|
||||||
yield res
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||||
|
output: Union[None, str] = None,
|
||||||
|
reflection: Union[None, str] = None,
|
||||||
|
corrected: Union[None, str] = None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
yield self._build_result(
|
||||||
|
instruction,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
reflection,
|
||||||
|
corrected,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return REPR_TEMPLATE.format(
|
||||||
|
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
@@ -247,7 +281,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if role_key_model:
|
if role_key_model:
|
||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
|
|
||||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
def _build_result(self, source):
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
@@ -282,11 +316,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_turns():
|
return conv.get_turns()
|
||||||
|
|
||||||
|
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||||
|
turns = self._build_result(source)
|
||||||
|
|
||||||
|
for part in turns:
|
||||||
if part[0] and not part[1]:
|
if part[0] and not part[1]:
|
||||||
LOG.warning(f"role with empty message: {part[0]}")
|
LOG.warning(f"role with empty message: {part[0]}")
|
||||||
yield part
|
yield part
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||||
|
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||||
"""
|
"""
|
||||||
@@ -304,3 +347,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
role_key_human=role_key_human,
|
role_key_human=role_key_human,
|
||||||
role_key_model=role_key_model,
|
role_key_model=role_key_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedPrompter:
|
||||||
|
"""
|
||||||
|
A dummy class for custom prompters
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -36,6 +36,7 @@ from axolotl.prompters import (
|
|||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
@@ -55,9 +56,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -70,7 +72,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
return train_dataset, eval_dataset, cfg.max_steps
|
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
@@ -83,7 +85,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
return train_dataset, eval_dataset, total_num_steps
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
else Path(default_dataset_prepared_path) / ds_hash
|
else Path(default_dataset_prepared_path) / ds_hash
|
||||||
)
|
)
|
||||||
dataset = None
|
dataset = None
|
||||||
|
prompters = []
|
||||||
use_auth_token = cfg.hf_use_auth_token
|
use_auth_token = cfg.hf_use_auth_token
|
||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
@@ -147,13 +150,13 @@ def load_tokenized_prepared_datasets(
|
|||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for d in for_d_in_datasets(cfg.datasets):
|
for config_dataset in for_d_in_datasets(cfg.datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
load_dataset(
|
load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
@@ -162,33 +165,33 @@ def load_tokenized_prepared_datasets(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(d.path)
|
local_path = Path(config_dataset.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
data_files=d.data_files,
|
data_files=config_dataset.data_files,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = "json"
|
ds_type = "json"
|
||||||
if d.ds_type:
|
if config_dataset.ds_type:
|
||||||
ds_type = d.ds_type
|
ds_type = config_dataset.ds_type
|
||||||
elif ".parquet" in d.path:
|
elif ".parquet" in config_dataset.path:
|
||||||
ds_type = "parquet"
|
ds_type = "parquet"
|
||||||
elif ".arrow" in d.path:
|
elif ".arrow" in config_dataset.path:
|
||||||
ds_type = "arrow"
|
ds_type = "arrow"
|
||||||
elif ".csv" in d.path:
|
elif ".csv" in config_dataset.path:
|
||||||
ds_type = "csv"
|
ds_type = "csv"
|
||||||
elif ".txt" in d.path:
|
elif ".txt" in config_dataset.path:
|
||||||
ds_type = "text"
|
ds_type = "text"
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
ds_type,
|
ds_type,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
data_files=d.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
@@ -198,25 +201,25 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
d.path,
|
config_dataset.path,
|
||||||
name=d.name,
|
name=config_dataset.name,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
data_files=d.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(d.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
repo_id=d.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=d.data_files,
|
filename=config_dataset.data_files,
|
||||||
)
|
)
|
||||||
elif isinstance(d.data_files, list):
|
elif isinstance(config_dataset.data_files, list):
|
||||||
fp = []
|
fp = []
|
||||||
for file in d.data_files:
|
for file in config_dataset.data_files:
|
||||||
fp.append(
|
fp.append(
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
repo_id=d.path,
|
repo_id=config_dataset.path,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
filename=file,
|
filename=file,
|
||||||
)
|
)
|
||||||
@@ -226,21 +229,27 @@ def load_tokenized_prepared_datasets(
|
|||||||
"data_files must be either a string or list of strings"
|
"data_files must be either a string or list of strings"
|
||||||
)
|
)
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
"json",
|
||||||
|
name=config_dataset.name,
|
||||||
|
data_files=fp,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise ValueError("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if config_dataset.shards:
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||||
num_shards=d.shards, index=0
|
num_shards=config_dataset.shards, index=0
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
ds = ds.shuffle(seed=seed).shard(
|
||||||
|
num_shards=config_dataset.shards, index=0
|
||||||
|
)
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
d_type = d.type
|
d_type = config_dataset.type
|
||||||
if isinstance(d_type, str):
|
if isinstance(d_type, str):
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
@@ -249,108 +258,26 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
elif (
|
elif (
|
||||||
isinstance(ds, DatasetDict)
|
isinstance(ds, DatasetDict)
|
||||||
and d.train_on_split
|
and config_dataset.train_on_split
|
||||||
and d.train_on_split in ds
|
and config_dataset.train_on_split in ds
|
||||||
):
|
):
|
||||||
ds = ds[d.train_on_split]
|
ds = ds[config_dataset.train_on_split]
|
||||||
elif isinstance(ds, DatasetDict):
|
elif isinstance(ds, DatasetDict):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
||||||
)
|
|
||||||
if (
|
|
||||||
"input_ids" in ds.features
|
|
||||||
and "attention_mask" in ds.features
|
|
||||||
and "labels" in ds.features
|
|
||||||
):
|
|
||||||
# dataset is already tokenized, just drop it straight in
|
|
||||||
datasets.append(ds)
|
|
||||||
elif isinstance(d.type, DictDefault):
|
|
||||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "alpaca":
|
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
|
||||||
AlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "explainchoice":
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
MultipleChoiceExplainPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "concisechoice":
|
|
||||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
|
||||||
MultipleChoiceConcisePrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "summarizetldr":
|
|
||||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
|
||||||
SummarizeTLDRPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "jeopardy":
|
|
||||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
|
||||||
JeopardyPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "oasst":
|
|
||||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
|
||||||
AlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "gpteacher":
|
|
||||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
|
||||||
GPTeacherPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
elif d_base_type == "reflection":
|
|
||||||
ds_strategy = AlpacaReflectionPTStrategy(
|
|
||||||
ReflectAlpacaPrompter(d_prompt_style),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
else:
|
|
||||||
suffix = ""
|
|
||||||
if ":load_" in d.type:
|
|
||||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
|
||||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
|
||||||
raise ValueError(
|
|
||||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||||
|
config_dataset=config_dataset,
|
||||||
|
dataset=ds,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
d_base_type=d_base_type,
|
||||||
|
d_prompt_style=d_prompt_style,
|
||||||
|
)
|
||||||
|
datasets.append(dataset_wrapper)
|
||||||
|
prompters.append(dataset_prompter)
|
||||||
|
|
||||||
LOG.info("merging datasets")
|
LOG.info("merging datasets")
|
||||||
dataset = concatenate_datasets(datasets)
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
@@ -368,14 +295,14 @@ def load_tokenized_prepared_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset]:
|
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -384,6 +311,7 @@ def load_prepare_datasets(
|
|||||||
) # make sure we don't accidentally set it larger than sequence_len
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
|
prompters = []
|
||||||
if cfg.max_packed_sequence_len is not None:
|
if cfg.max_packed_sequence_len is not None:
|
||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
@@ -439,7 +367,7 @@ def load_prepare_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -481,7 +409,7 @@ def load_prepare_datasets(
|
|||||||
private=True,
|
private=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -532,7 +460,124 @@ def load_prepare_datasets(
|
|||||||
train_dataset = dataset
|
train_dataset = dataset
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_wrapper(
    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
    """Pick the tokenized wrapper and the prompter for a single dataset.

    The candidates are tried in this order:

    1. ``dataset`` already has ``input_ids``, ``attention_mask`` and
       ``labels`` features: it is returned unchanged.
    2. ``config_dataset.type`` is a ``DictDefault``: a ``"user_defined"``
       strategy is loaded from its dict form.
    3. ``load(config_dataset.type, ...)`` returns a truthy strategy.
    4. ``d_base_type`` names one of the built-in prompt formats listed in
       the table below.

    Cases 1-3 report an ``UnsupportedPrompter``; case 4 reports the prompter
    that was handed to the tokenizing strategy.

    Returns:
        A ``(dataset_wrapper, dataset_prompter)`` tuple.

    Raises:
        ValueError: if none of the cases above applies.
    """
    features = dataset.features
    if all(key in features for key in ("input_ids", "attention_mask", "labels")):
        # Pre-tokenized data needs no strategy; pass it through untouched.
        return dataset, UnsupportedPrompter()

    requested_type = config_dataset.type

    if isinstance(requested_type, DictDefault):
        strategy = load("user_defined", tokenizer, cfg, requested_type.to_dict())
        prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(strategy, dataset), prompter

    strategy = load(requested_type, tokenizer, cfg, config_dataset)
    if strategy:
        prompter = UnsupportedPrompter()
        return TokenizedPromptDataset(strategy, dataset), prompter

    # Built-in formats: base type -> (prompter class, tokenizing strategy class).
    builtin_formats = {
        "alpaca": (AlpacaPrompter, AlpacaPromptTokenizingStrategy),
        "explainchoice": (
            MultipleChoiceExplainPrompter,
            AlpacaMultipleChoicePromptTokenizingStrategy,
        ),
        "concisechoice": (
            MultipleChoiceConcisePrompter,
            AlpacaMultipleChoicePromptTokenizingStrategy,
        ),
        "summarizetldr": (
            SummarizeTLDRPrompter,
            SummarizeTLDRPromptTokenizingStrategy,
        ),
        "jeopardy": (JeopardyPrompter, JeopardyPromptTokenizingStrategy),
        "oasst": (AlpacaPrompter, OpenAssistantPromptTokenizingStrategy),
        "gpteacher": (GPTeacherPrompter, GPTeacherPromptTokenizingStrategy),
        "reflection": (ReflectAlpacaPrompter, AlpacaReflectionPTStrategy),
    }
    # Only string keys can match; the guard keeps a non-string base type on
    # the "unhandled" path instead of risking a hashing error in the lookup.
    match = (
        builtin_formats.get(d_base_type) if isinstance(d_base_type, str) else None
    )
    if match is not None:
        prompter_cls, strategy_cls = match
        prompter = prompter_cls(d_prompt_style)
        strategy = strategy_cls(
            prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        return TokenizedPromptDataset(strategy, dataset), prompter

    # Nothing matched. If the type contains ":load_", suggest the ".load_" form.
    if ":load_" in requested_type:
        corrected = requested_type.replace(":load_", ".load_")
        suffix = f" Did you mean {corrected}?"
    else:
        suffix = ""
    LOG.error(f"unhandled prompt tokenization strategy: {requested_type}. {suffix}")
    raise ValueError(
        f"unhandled prompt tokenization strategy: {requested_type} {suffix}"
    )
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
|
|||||||
Reference in New Issue
Block a user