diff --git a/docs/rlhf.md b/docs/rlhf.md new file mode 100644 index 000000000..371a40dbf --- /dev/null +++ b/docs/rlhf.md @@ -0,0 +1,35 @@ +# RLHF (Beta) + +### Overview + +Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human +feedback. Various methods include, but not limited to: + +- Proximal Policy Optimization (PPO) (not yet supported in axolotl) +- Direct Preference Optimization (DPO) +- Identity Preference Optimization (IPO) + + +### RLHF using Axolotl + +[!IMPORTANT] +This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. + +The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML + +#### DPO +```yaml +rl: true +datasets: + - path: Intel/orca_dpo_pairs + split: train + type: intel_apply_chatml + - path: argilla/ultrafeedback-binarized-preferences + split: train + type: argilla_apply_chatml +``` + +#### IPO +```yaml +rl: ipo +``` diff --git a/requirements.txt b/requirements.txt index f4df0dd67..14f6633f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,5 @@ tensorboard s3fs gcsfs # adlfs + +trl @ git+https://github.com/huggingface/trl.git@main diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 85f6b358a..0477ebebf 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -2,6 +2,7 @@ import importlib import logging +import math import os import random import sys @@ -16,6 +17,7 @@ import yaml # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art +from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer @@ -325,6 +327,94 @@ def load_datasets( ) +def load_rl_datasets( + *, + cfg: DictDefault, + cli_args: TrainerCliArgs, # pylint: disable=unused-argument +) -> TrainDatasetMeta: + train_datasets: List[Any] = [] + for i, ds_cfg in enumerate(cfg.datasets): + train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"])) + # eval_dataset = load_dataset( + # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"] + # ) + eval_dataset = None + + def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" + sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + return sample + + def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen']}<|im_end|>" + sample["rejected"] = f"{sample['rejected']}<|im_end|>" + return sample + + def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable + if "system" in sample and sample["system"]: + sample["prompt"] = ( + f"<|im_start|>system\n{sample['system']}<|im_end|>\n" + f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + sample[ + "prompt" + ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>" + sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>" + return sample + + for i, data_set in enumerate(train_datasets): + _type = cfg.datasets[i]["type"] + ds_type_fn = locals()[_type] + train_datasets[i] = data_set.map(ds_type_fn) + train_dataset = concatenate_datasets(train_datasets) + + # eval_dataset = eval_dataset.map(intel_apply_chatml) + + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + + return TrainDatasetMeta( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + total_num_steps=total_num_steps, + ) + + def check_accelerate_default_config(): if Path(config_args.default_yaml_config_file).exists(): LOG.warning( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 81307b6b9..2248784df 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -12,6 +12,7 @@ from axolotl.cli import ( check_user_token, load_cfg, load_datasets, + load_rl_datasets, print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs @@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) - dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + if parsed_cfg.rl: + dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) + else: + dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4ca2877d1..1ca36eb41 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_utils import seed_worker +from trl import DPOTrainer from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.utils.callbacks import ( @@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC): _train_dataset = None _eval_dataset = None + _model_ref = None def __init__(self, cfg, model, tokenizer): self.cfg = cfg self.model = model self.tokenizer = tokenizer + @property + def model_ref(self): + return self._model_ref + + @model_ref.setter + def model_ref(self, model): + self._model_ref = model + @property def train_dataset(self): return self._train_dataset @@ -827,3 +837,96 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return_tensors="pt", **kwargs, ) + + +class HFDPOTrainerBuilder(TrainerBuilderBase): + """ + Trainer factory class for DPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build_training_arguments(self, total_num_steps): + training_args_kwargs = {} + for arg in [ + "adam_beta1", + "adam_beta2", + "adam_epsilon", + "dataloader_num_workers", + "dataloader_pin_memory", + ]: + if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: + training_args_kwargs[arg] = getattr(self.cfg, arg) + training_args = TrainingArguments( + per_device_train_batch_size=self.cfg.micro_batch_size, + max_steps=total_num_steps, + remove_unused_columns=False, + gradient_accumulation_steps=self.cfg.gradient_accumulation_steps, + learning_rate=self.cfg.learning_rate, + evaluation_strategy="no", + # eval_steps=self.cfg.eval_steps, + save_strategy="steps", + save_steps=self.cfg.save_steps, + output_dir=self.cfg.output_dir, + warmup_steps=self.cfg.warmup_steps, + bf16=True, + gradient_checkpointing=self.cfg.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + logging_first_step=True, + logging_steps=1, + optim=self.cfg.optimizer, + save_total_limit=self.cfg.save_total_limit or 5, + **training_args_kwargs, + ) + + return training_args + + def build(self, total_num_steps): + training_args = self.build_training_arguments(total_num_steps) + dpo_trainer_kwargs = {} + if self.cfg.rl == "ipo": + dpo_trainer_kwargs["loss_type"] = "ipo" + if self.cfg.dpo_label_smoothing: + dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + + dpo_trainer = DPOTrainer( + self.model, + self.model_ref, + args=training_args, + beta=self.cfg.dpo_beta or 0.1, + train_dataset=self.train_dataset, + # eval_dataset=self.eval_dataset, + eval_dataset=None, + tokenizer=self.tokenizer, + max_length=self.cfg.sequence_len, + max_target_length=None, + max_prompt_length=self.cfg.sequence_len, + generate_during_eval=True, + **dpo_trainer_kwargs, + ) + + return dpo_trainer + + +class HFPPOTrainerBuilder(TrainerBuilderBase): + """ + HF Factory class for PPO Trainer + """ + + def get_callbacks(self): + callbacks = [] + return callbacks + + def get_post_trainer_create_callbacks(self, trainer): + callbacks = [] + return callbacks + + def build(self, total_num_steps): + # build PPOConfig + pass diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py new file mode 100644 index 000000000..24c0b0412 --- /dev/null +++ b/src/axolotl/core/trainers/trl.py @@ -0,0 +1,66 @@ +""" +module for TRL PPO training +""" +import torch +from tqdm import tqdm +from trl import PPOTrainer + + +class TRLPPOTrainer(PPOTrainer): + """ + wrapper for ppo trainer to handle customizations + """ + + def train( + self, + reward_pipe, + resume_from_checkpoint=None, # pylint: disable=unused-argument + ): + generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.tokenizer.eos_token_id, + "max_new_tokens": 32, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + } + + for epoch, batch in tqdm( # pylint: disable=unused-variable + enumerate(self.dataloader) + ): + query_tensors = batch["input_ids"] + + # generate model response + response_tensors, ref_response_tensors = self.generate( + query_tensors, + return_prompt=False, + generate_ref_response=True, + **generation_kwargs + ) + batch["response"] = self.tokenizer.batch_decode(response_tensors) + batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = reward_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs] + ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])] + ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs) + ref_rewards = [ + torch.tensor(output[1]["score"]) for output in ref_pipe_outputs + ] + batch["ref_rewards"] = ref_rewards + + # Run PPO step + stats = self.step(query_tensors, response_tensors, rewards) + self.log_stats( + stats, + batch, + rewards, + columns_to_log=["query", "response", "ref_response", "ref_rewards"], + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 4e5241e4c..e0da11252 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -61,6 +61,12 @@ def train( msg += " and peft_config..." LOG.debug(msg) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) + model_ref = None + if cfg.rl: + # load the model again for model_ref/baseline + model_ref, _ = load_model( + cfg, tokenizer, inference=cli_args.inference, reference_model=True + ) safe_serialization = cfg.save_safetensors is True @@ -83,7 +89,7 @@ def train( freeze_parameters_except(model, cfg.unfrozen_parameters) trainer = setup_trainer( - cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps + cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps ) if hasattr(model, "config"): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fb2420108..b30ffcad8 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -200,6 +200,7 @@ def load_model( cfg: DictDefault, tokenizer: PreTrainedTokenizerBase, inference: bool = False, + reference_model: bool = False, ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: """ Load a model for a given configuration and tokenizer. @@ -290,6 +291,15 @@ def load_model( model_kwargs["device_map"] = cfg.device_map model_kwargs["max_memory"] = cfg.max_memory model_kwargs["torch_dtype"] = cfg.torch_dtype + # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) if is_deepspeed_zero3_enabled(): del model_kwargs["device_map"] @@ -560,9 +570,11 @@ def load_model( if hasattr(module, "weight"): module.to(cfg.torch_dtype) - model, lora_config = load_adapter(model, cfg, cfg.adapter) + lora_config = None + if not reference_model or cfg.lora_model_dir: + model, lora_config = load_adapter(model, cfg, cfg.adapter) - if cfg.ddp and not load_in_8bit: + if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): model.to(f"cuda:{cfg.local_rank}") if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index f046dd7be..d975bb9a2 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -12,7 +12,7 @@ from accelerate.logging import get_logger from datasets import set_caching_enabled from torch.utils.data import DataLoader, RandomSampler -from axolotl.core.trainer_builder import HFCausalTrainerBuilder +from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first from axolotl.utils.samplers import MultipackBatchSampler @@ -280,7 +280,12 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) + if cfg.rl: + trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.model_ref = model[1] + else: + trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer) + trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py new file mode 100644 index 000000000..e8987ef45 --- /dev/null +++ b/tests/core/test_trainer_builder.py @@ -0,0 +1,59 @@ +""" +unit tests for axolotl.core.trainer_builder +""" +import pytest + +from axolotl.core.trainer_builder import HFDPOTrainerBuilder +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + + +@pytest.fixture(name="cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "LlamaTokenizer", + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.00005, + "save_steps": 100, + "output_dir": "./model-out", + "warmup_steps": 10, + "gradient_checkpointing": False, + "optimizer": "adamw_torch", + "sequence_len": 2048, + "rl": True, + "adam_beta1": 0.998, + "adam_beta2": 0.9, + "adam_epsilon": 0.00001, + "dataloader_num_workers": 1, + "dataloader_pin_memory": True, + } + ) + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(cfg): + return load_tokenizer(cfg) + + +@pytest.fixture(name="model") +def fixture_model(cfg, tokenizer): + return load_model(cfg, tokenizer) + + +class TestHFDPOTrainerBuilder: + """ + TestCase class for DPO trainer builder + """ + + def test_build_training_arguments(self, cfg, model, tokenizer): + builder = HFDPOTrainerBuilder(cfg, model, tokenizer) + training_arguments = builder.build_training_arguments(100) + assert training_arguments.adam_beta1 == 0.998 + assert training_arguments.adam_beta2 == 0.9 + assert training_arguments.adam_epsilon == 0.00001 + assert training_arguments.dataloader_num_workers == 1 + assert training_arguments.dataloader_pin_memory is True