RL/DPO (#935)
* ipo-dpo trainer * fix missing abstract method * chatml template, grad checkpointing kwargs support * fix steps calc for RL and add dataloader kwargs * wip to fix dpo and start ppo * more fixes * refactor to generalize map fn * fix dataset loop and handle argilla pref dataset * set training args * load reference model on seperate gpu if more than one device * no auto upload to hub for dpo, don't add lora adapters to ref model for dpo * fixes for rl training * support for ipo from yaml * set dpo training args from the config, add tests * chore: lint * set sequence_len for model in test * add RLHF docs
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
@@ -16,6 +17,7 @@ import yaml
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
@@ -325,6 +327,94 @@ def load_datasets(
|
||||
)
|
||||
|
||||
|
||||
def load_rl_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
|
||||
) -> TrainDatasetMeta:
|
||||
train_datasets: List[Any] = []
|
||||
for i, ds_cfg in enumerate(cfg.datasets):
|
||||
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
|
||||
# eval_dataset = load_dataset(
|
||||
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
|
||||
# )
|
||||
eval_dataset = None
|
||||
|
||||
def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
for i, data_set in enumerate(train_datasets):
|
||||
_type = cfg.datasets[i]["type"]
|
||||
ds_type_fn = locals()[_type]
|
||||
train_datasets[i] = data_set.map(ds_type_fn)
|
||||
train_dataset = concatenate_datasets(train_datasets)
|
||||
|
||||
# eval_dataset = eval_dataset.map(intel_apply_chatml)
|
||||
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
|
||||
def check_accelerate_default_config():
|
||||
if Path(config_args.default_yaml_config_file).exists():
|
||||
LOG.warning(
|
||||
|
||||
Reference in New Issue
Block a user