Reward model (#1879)

This commit is contained in:
Wing Lian
2024-10-13 15:11:13 -04:00
committed by GitHub
parent cd2d89f467
commit 68b1369de9
12 changed files with 382 additions and 21 deletions

View File

@@ -19,6 +19,7 @@ from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
@@ -459,7 +460,7 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if not cfg.skip_prepare_dataset:
if cfg.sample_packing and not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
@@ -609,7 +610,20 @@ def get_dataset_wrapper(
)
elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
elif ds_strategy := config_dataset.type.startswith(
"bradley_terry"
) and bradley_terry_load(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
if isinstance(ds_strategy, DatasetWrappingStrategy):
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else: