ORPO Trainer replacement (#1551)
* WIP: use trl ORPOTrainer
* fixes to make ORPO work with trl
* fix the chat template loading
* make sure to handle the special tokens and add_generation_prompt for the assistant turn too
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs

-trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+trl==0.8.5
 zstandard==0.22.0
 fastcore
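Note: the dependency moves from a pinned trl git commit to the tagged 0.8.5 release, which ships the ORPOTrainer/ORPOConfig classes this change builds on. A quick sanity check (an illustration, not part of the diff):

    # confirm the installed trl release exposes the ORPO pieces used below
    import trl
    from trl import ORPOConfig, ORPOTrainer  # ImportError on releases without ORPO

    print(trl.__version__)  # expect "0.8.5" with the pin above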
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()

-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length

 from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res


+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )


-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
     Trainer factory class for DPO Trainer
     """
@@ -1497,7 +1505,15 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"

-        training_args = TrainingArguments(
+        if self.cfg.orpo_alpha:
+            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_cls = TrainingArguments
+        if self.cfg.rl == "orpo":
+            training_args_cls = ORPOConfig
+
+        training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
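Note on the alpha-to-beta mapping flagged in the comment above: axolotl's cfg.orpo_alpha is the weight on the odds-ratio term (called lambda in the ORPO paper), but trl's ORPOConfig reuses the name beta for that coefficient, so the value is passed straight through as beta. A rough sketch of the loss shape, assuming chosen/rejected average per-token log-probs as inputs (simplified relative to trl's actual implementation):

    import torch
    import torch.nn.functional as F

    def orpo_loss_sketch(chosen_logps, rejected_logps, beta=0.1):
        # odds(y) = p / (1 - p); compute the log odds ratio in log space for stability
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps))
            - torch.log1p(-torch.exp(rejected_logps))
        )
        or_term = -F.logsigmoid(log_odds)  # odds-ratio preference penalty
        nll = -chosen_logps                # plain SFT loss on the chosen answer
        return (nll + beta * or_term).mean()  # beta here carries cfg.orpo_alpha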
@@ -1530,17 +1546,26 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
            dpo_trainer_kwargs[
                "precompute_ref_log_probs"
            ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
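For orientation, the ORPO branch above reduces to roughly this direct trl usage (a sketch under assumptions: the model name is a placeholder, and the tiny inline dataset stands in for the prompt/chosen/rejected columns produced by the transform further down):

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import ORPOConfig, ORPOTrainer

    model_name = "facebook/opt-125m"  # placeholder for illustration only
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # ORPOTrainer consumes plain-text prompt/chosen/rejected columns
    train_dataset = Dataset.from_dict(
        {
            "prompt": ["What is 2 + 2? "],
            "chosen": ["2 + 2 = 4."],
            "rejected": ["2 + 2 = 5."],
        }
    )

    args = ORPOConfig(
        output_dir="./orpo-out",
        beta=0.1,         # receives cfg.orpo_alpha in the builder above
        max_length=2048,  # cfg.sequence_len
        max_prompt_length=2048,
        per_device_train_batch_size=1,
        max_steps=1,
    )
    trainer = ORPOTrainer(
        model, args=args, train_dataset=train_dataset, tokenizer=tokenizer
    )
    trainer.train()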
@@ -6,4 +6,4 @@ from functools import partial

 from ..base import load as load_base

-load = partial(load_base, module="axolotl.prompt_strategies.orpo")
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
         )
         return MessageList(messages=messages)

+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+

 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
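To make the parsing above concrete, a hypothetical argilla-style sample (the keys match the code; the content is invented):

    sample = {
        "system": "You are a helpful assistant.",  # optional
        "chosen": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "7 is a prime number."},
        ],
        "rejected": [
            {"role": "user", "content": "Name a prime number."},
            {"role": "assistant", "content": "9 is a prime number."},
        ],
    }

    # get_prompt(sample)   -> system + user turns, everything before the final answer
    # get_chosen(sample)   -> prompt plus the final "chosen" assistant turn (label=True)
    # get_rejected(sample) -> prompt plus the final "rejected" assistant turn (label=True)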
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
             chat_template=self.chat_template,
             tokenize=False,
         ), True
+
+
+def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
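Note: rendering the prompt with add_generation_prompt=True makes it end exactly where the assistant's reply begins, so slicing the fully rendered chosen/rejected strings at prompt_str_len leaves only the completions, which is the prompt/chosen/rejected split ORPOTrainer expects. A toy illustration with invented chatml-style strings (real output depends on the configured chat template):

    prompt = "<|im_start|>user\nName a prime.<|im_end|>\n<|im_start|>assistant\n"
    full_chosen = prompt + "7 is prime.<|im_end|>\n"
    chosen = full_chosen[len(prompt):]  # -> "7 is prime.<|im_end|>\n"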
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
@@ -1,17 +1,20 @@
 """data handling specific to DPO"""

+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List

 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer

 LOG = logging.getLogger("axolotl")
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
             )
             split_datasets.insert(i, ds)

+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
             )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
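The tokenizer plumbing above uses signature inspection so transforms that don't take a tokenizer keep working unchanged, and the tokenizer is loaded at most once across datasets. A self-contained sketch of the same pattern (all names invented for illustration):

    import inspect
    from functools import partial

    class FakeTokenizer:  # stand-in so the sketch runs without loading a model
        def encode(self, text):
            return text.split()

    def with_tokenizer(sample, tokenizer=None):  # hypothetical transform
        return {"n_tokens": len(tokenizer.encode(sample["text"]))}

    def without_tokenizer(sample):  # hypothetical transform
        return {"upper": sample["text"].upper()}

    tokenizer = None
    for fn in (with_tokenizer, without_tokenizer):
        if "tokenizer" in inspect.signature(fn).parameters:
            tokenizer = tokenizer or FakeTokenizer()  # lazy, created at most once
            fn = partial(fn, tokenizer=tokenizer)
        print(fn({"text": "hello world"}))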
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):


 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
 """

 import pytest

-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)


-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """

     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9