From 7d1d22f72f53af613cc168f3d473e174f0b79b47 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 19 Apr 2024 17:25:36 -0400
Subject: [PATCH] ORPO Trainer replacement (#1551)

* WIP use trl ORPOTrainer

* fixes to make orpo work with trl

* fix the chat template loading

* make sure to handle the special tokens and add_generation_prompt for the assistant turn too

---
 requirements.txt                              |  2 +-
 src/axolotl/cli/preprocess.py                 |  2 +-
 src/axolotl/cli/train.py                      |  2 +-
 src/axolotl/core/trainer_builder.py           | 47 ++++++++---
 .../prompt_strategies/orpo/__init__.py        |  2 +-
 .../prompt_strategies/orpo/chat_template.py   | 84 +++++++++++++++++++
 src/axolotl/utils/data/__init__.py            |  2 +-
 src/axolotl/utils/data/{dpo.py => rl.py}      | 24 +++++-
 src/axolotl/utils/trainer.py                  |  6 +-
 tests/core/test_trainer_builder.py            |  6 +-
 10 files changed, 151 insertions(+), 26 deletions(-)
 rename src/axolotl/utils/data/{dpo.py => rl.py} (80%)

diff --git a/requirements.txt b/requirements.txt
index 4d74dee90..9289a40f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@
 s3fs
 gcsfs
 # adlfs
-trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+trl==0.8.5
 zstandard==0.22.0
 fastcore

diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index a1a01d59d..fa71d6793 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
index 7e004567a..0cebe5a52 100644
--- a/src/axolotl/cli/train.py
+++ b/src/axolotl/cli/train.py
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 900dcb788..fdb081003 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res
 
 
+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer with axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
 
 
-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
-    Trainer factory class for DPO Trainer
+    Trainer factory class for the TRL RL trainers (DPO, ORPO)
     """
@@ -1497,7 +1505,15 @@
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
-        training_args = TrainingArguments(
+        if self.cfg.orpo_alpha:
+            # trl reuses the DPO-style `beta` config field to carry ORPO's alpha
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_cls = TrainingArguments
+        if self.cfg.rl == "orpo":
+            training_args_cls = ORPOConfig
+
+        training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1530,17 +1546,26 @@
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
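Note on the trl API above: for `rl: orpo` the builder swaps `TrainingArguments` for `ORPOConfig` and builds the trainer without a reference model, since ORPO folds the preference penalty into the SFT loss. Below is a minimal standalone sketch of the trl 0.8.5 surface the builder targets, stripped of axolotl plumbing; the model name, dataset rows, and hyperparameter values are illustrative placeholders, not axolotl defaults.

    # Minimal sketch of trl 0.8.5's ORPO API (placeholder values throughout).
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import ORPOConfig, ORPOTrainer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

    # ORPOTrainer consumes plain string columns: prompt, chosen, rejected.
    train_ds = Dataset.from_dict(
        {
            "prompt": ["What is 2+2? "],
            "chosen": ["4"],
            "rejected": ["22"],
        }
    )

    # ORPOConfig subclasses TrainingArguments; its `beta` field is what the
    # builder fills from cfg.orpo_alpha (the odds-ratio weight in the paper).
    args = ORPOConfig(
        output_dir="./orpo-out",
        beta=0.1,
        max_length=512,
        max_prompt_length=256,
        per_device_train_batch_size=1,
        max_steps=1,
    )

    # Unlike the DPO branch above, no reference model is passed in.
    trainer = ORPOTrainer(model, args=args, train_dataset=train_ds, tokenizer=tokenizer)
    trainer.train()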
diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py
index 3a961fcc9..4a02f3c62 100644
--- a/src/axolotl/prompt_strategies/orpo/__init__.py
+++ b/src/axolotl/prompt_strategies/orpo/__init__.py
@@ -6,4 +6,4 @@ from functools import partial
 
 from ..base import load as load_base
 
-load = partial(load_base, module="axolotl.prompt_strategies.orpo")
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")

diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py
index 9953fe87e..a89dee157 100644
--- a/src/axolotl/prompt_strategies/orpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/orpo/chat_template.py
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
         )
         return MessageList(messages=messages)
 
+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+
 
 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
             chat_template=self.chat_template,
             tokenize=False,
         ), True
+
+
+def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
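To make the parsing and rendering above concrete, here is a hypothetical sample in the chosen/rejected message-list format the strategy consumes, and how it splits. The sample content is invented for illustration; the class and methods are the ones added in this file.

    # Hypothetical input record; field names follow the parser above.
    sample = {
        "system": "You are concise.",
        "chosen": [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ],
        "rejected": [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "22"},
        ],
    }

    parser = ORPODatasetParsingStrategy()
    prompt_msgs = parser.get_prompt(sample).messages      # system + user turns only
    chosen_msgs = parser.get_chosen(sample).messages      # ... + final chosen reply
    rejected_msgs = parser.get_rejected(sample).messages  # ... + final rejected reply

    # transform_fn renders each list with the chat template; because the prompt
    # is rendered with add_generation_prompt=True and its length is then sliced
    # off via prompt_str_len, "chosen"/"rejected" end up completion-only, e.g.
    # under chatml roughly "4<|im_end|>" vs "22<|im_end|>".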
diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py
index 1015c8370..140d02106 100644
--- a/src/axolotl/utils/data/__init__.py
+++ b/src/axolotl/utils/data/__init__.py
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,

diff --git a/src/axolotl/utils/data/dpo.py b/src/axolotl/utils/data/rl.py
similarity index 80%
rename from src/axolotl/utils/data/dpo.py
rename to src/axolotl/utils/data/rl.py
index 765a3fc37..ff5ca87dd 100644
--- a/src/axolotl/utils/data/dpo.py
+++ b/src/axolotl/utils/data/rl.py
@@ -1,17 +1,20 @@
-"""data handling specific to DPO"""
-
+"""data handling specific to RL trainers (DPO, ORPO)"""
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl")
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
                 )
                 split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
             )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
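The `inspect.signature` check above is a small reusable pattern: a transform opts in to receiving a tokenizer simply by declaring a `tokenizer` keyword argument, and the loader binds one lazily so it is only constructed when some dataset actually needs it. A self-contained sketch of the same pattern with toy functions (not axolotl code):

    import inspect
    from functools import partial


    def with_tokenizer(sample, tokenizer=None):
        return {"n_tokens": len(tokenizer.encode(sample["text"]))}


    def without_tokenizer(sample):
        return {"n_chars": len(sample["text"])}


    def bind_tokenizer(transform_fn, tokenizer):
        # Pass tokenizer= only to transforms whose signature asks for it.
        if "tokenizer" in inspect.signature(transform_fn).parameters:
            return partial(transform_fn, tokenizer=tokenizer)
        return transform_fn


    class ToyTokenizer:
        def encode(self, text):
            return text.split()


    fn = bind_tokenizer(with_tokenizer, ToyTokenizer())
    assert fn({"text": "a b c"}) == {"n_tokens": 3}
    assert bind_tokenizer(without_tokenizer, None)({"text": "abc"}) == {"n_chars": 3}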
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2a8ed216d..808fbb59f 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:

diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index 541fdb343..82455922e 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
 
 import pytest
 
-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)
 
 
-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
-    TestCase class for DPO trainer builder
+    TestCase class for the RL trainer builder
     """
 
     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9
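Taken together, an ORPO run now follows the same route as DPO: a config with `rl: orpo` reaches `HFRLTrainerBuilder` through `setup_trainer`. The sketch below is patterned on the test above and is only illustrative; the base model and hyperparameter values are placeholders, not a tested recipe.

    from axolotl.core.trainer_builder import HFRLTrainerBuilder
    from axolotl.utils.config import normalize_config
    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_model, load_tokenizer

    cfg = DictDefault(
        {
            "base_model": "JackFram/llama-68m",  # placeholder tiny model
            "rl": "orpo",
            "orpo_alpha": 0.1,  # surfaces as ORPOConfig.beta inside the builder
            "chat_template": "chatml",
            "learning_rate": 5e-5,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "sequence_len": 2048,
        }
    )
    normalize_config(cfg)
    tokenizer = load_tokenizer(cfg)
    model = load_model(cfg, tokenizer)

    builder = HFRLTrainerBuilder(cfg, model, tokenizer)
    training_arguments = builder.build_training_arguments(100)  # as in the test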