* ipo-dpo trainer

* fix missing abstract method

* chatml template, grad checkpointing kwargs support

* fix steps calc for RL and add dataloader kwargs

* wip to fix dpo and start ppo

* more fixes

* refactor to generalize map fn

* fix dataset loop and handle argilla pref dataset

* set training args

* load reference model on seperate gpu if more than one device

* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo

* fixes for rl training

* support for ipo from yaml

* set dpo training args from the config, add tests

* chore: lint

* set sequence_len for model in test

* add RLHF docs
This commit is contained in:
Wing Lian
2024-01-04 18:21:25 -05:00
parent 59b2d302c8
commit f243c2186d
11 changed files with 388 additions and 6 deletions

35
docs/rlhf.md Normal file
View File

@@ -0,0 +1,35 @@
# RLHF (Beta)
### Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
### RLHF using Axolotl
[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
#### DPO
```yaml
rl: true
datasets:
- path: Intel/orca_dpo_pairs
split: train
type: intel_apply_chatml
- path: argilla/ultrafeedback-binarized-preferences
split: train
type: argilla_apply_chatml
```
#### IPO
```yaml
rl: ipo
```

View File

@@ -37,3 +37,5 @@ tensorboard
s3fs
gcsfs
# adlfs
trl @ git+https://github.com/huggingface/trl.git@main

View File

@@ -2,6 +2,7 @@
import importlib
import logging
import math
import os
import random
import sys
@@ -16,6 +17,7 @@ import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -325,6 +327,94 @@ def load_datasets(
)
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_datasets: List[Any] = []
for i, ds_cfg in enumerate(cfg.datasets):
train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
# eval_dataset = load_dataset(
# cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
# )
eval_dataset = None
def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
return sample
def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
def apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
for i, data_set in enumerate(train_datasets):
_type = cfg.datasets[i]["type"]
ds_type_fn = locals()[_type]
train_datasets[i] = data_set.map(ds_type_fn)
train_dataset = concatenate_datasets(train_datasets)
# eval_dataset = eval_dataset.map(intel_apply_chatml)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(

View File

@@ -12,6 +12,7 @@ from axolotl.cli import (
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
@@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
@@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC):
_train_dataset = None
_eval_dataset = None
_model_ref = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def model_ref(self):
return self._model_ref
@model_ref.setter
def model_ref(self, model):
self._model_ref = model
@property
def train_dataset(self):
return self._train_dataset
@@ -827,3 +837,96 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt",
**kwargs,
)
class HFDPOTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
"""
def get_callbacks(self):
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks
def build_training_arguments(self, total_num_steps):
training_args_kwargs = {}
for arg in [
"adam_beta1",
"adam_beta2",
"adam_epsilon",
"dataloader_num_workers",
"dataloader_pin_memory",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=total_num_steps,
remove_unused_columns=False,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
evaluation_strategy="no",
# eval_steps=self.cfg.eval_steps,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
bf16=True,
gradient_checkpointing=self.cfg.gradient_checkpointing,
gradient_checkpointing_kwargs={"use_reentrant": False},
logging_first_step=True,
logging_steps=1,
optim=self.cfg.optimizer,
save_total_limit=self.cfg.save_total_limit or 5,
**training_args_kwargs,
)
return training_args
def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
dpo_trainer = DPOTrainer(
self.model,
self.model_ref,
args=training_args,
beta=self.cfg.dpo_beta or 0.1,
train_dataset=self.train_dataset,
# eval_dataset=self.eval_dataset,
eval_dataset=None,
tokenizer=self.tokenizer,
max_length=self.cfg.sequence_len,
max_target_length=None,
max_prompt_length=self.cfg.sequence_len,
generate_during_eval=True,
**dpo_trainer_kwargs,
)
return dpo_trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
return callbacks
def build(self, total_num_steps):
# build PPOConfig
pass

View File

View File

@@ -0,0 +1,66 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer
class TRLPPOTrainer(PPOTrainer):
"""
wrapper for ppo trainer to handle customizations
"""
def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"]
# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards
# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)

View File

@@ -61,6 +61,12 @@ def train(
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model_ref = None
if cfg.rl:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
safe_serialization = cfg.save_safetensors is True
@@ -83,7 +89,7 @@ def train(
freeze_parameters_except(model, cfg.unfrozen_parameters)
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
)
if hasattr(model, "config"):

View File

@@ -200,6 +200,7 @@ def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
inference: bool = False,
reference_model: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
@@ -290,6 +291,15 @@ def load_model(
model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
# if reference_model:
# model_kwargs["device_map"] = "cuda:" + str(
# torch.cuda.current_device() + 1
# )
# else:
# model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
@@ -560,9 +570,11 @@ def load_model(
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
lora_config = None
if not reference_model or cfg.lora_model_dir:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
model.to(f"cuda:{cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:

View File

@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler
@@ -280,7 +280,12 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
if cfg.rl:
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -0,0 +1,59 @@
"""
unit tests for axolotl.core.trainer_builder
"""
import pytest
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@pytest.fixture(name="cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00005,
"save_steps": 100,
"output_dir": "./model-out",
"warmup_steps": 10,
"gradient_checkpointing": False,
"optimizer": "adamw_torch",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
}
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
return load_tokenizer(cfg)
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
class TestHFDPOTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True