ORPO Trainer replacement (#1551)
* WIP use trl ORPOTrainer * fixes to make orpo work with trl * fix the chat template laoding * make sure to handle the special tokens and add_generation for assistant turn too
This commit is contained in:
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
trl==0.8.5
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for trainer builder
|
Base class for trainer builder
|
||||||
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HFDPOTrainerBuilder(TrainerBuilderBase):
|
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Trainer factory class for DPO Trainer
|
Trainer factory class for DPO Trainer
|
||||||
"""
|
"""
|
||||||
@@ -1497,7 +1505,15 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args = TrainingArguments(
|
if self.cfg.orpo_alpha:
|
||||||
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
|
training_args_cls = TrainingArguments
|
||||||
|
if self.cfg.rl == "orpo":
|
||||||
|
training_args_cls = ORPOConfig
|
||||||
|
|
||||||
|
training_args = training_args_cls(
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
max_steps=self.cfg.max_steps or total_num_steps,
|
max_steps=self.cfg.max_steps or total_num_steps,
|
||||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||||
@@ -1530,17 +1546,26 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
dpo_trainer = AxolotlDPOTrainer(
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
||||||
self.model,
|
trainer_cls = AxolotlDPOTrainer
|
||||||
self.model_ref,
|
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||||
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
|
# these aren't used for the ORPO trainer
|
||||||
|
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
dpo_trainer_kwargs["max_target_length"] = None
|
||||||
|
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||||
|
elif self.cfg.rl == "orpo":
|
||||||
|
trainer_cls = AxolotlORPOTrainer
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
dpo_trainer = trainer_cls(
|
||||||
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
beta=self.cfg.dpo_beta or 0.1,
|
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_length=self.cfg.sequence_len,
|
|
||||||
max_target_length=None,
|
|
||||||
max_prompt_length=self.cfg.sequence_len,
|
|
||||||
generate_during_eval=True,
|
|
||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**dpo_trainer_kwargs,
|
**dpo_trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ from functools import partial
|
|||||||
|
|
||||||
from ..base import load as load_base
|
from ..base import load as load_base
|
||||||
|
|
||||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
|
||||||
|
|||||||
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
|
|||||||
)
|
)
|
||||||
return MessageList(messages=messages)
|
return MessageList(messages=messages)
|
||||||
|
|
||||||
|
def get_prompt(self, prompt) -> MessageList:
|
||||||
|
"""Map the data to extract everything up to the last turn"""
|
||||||
|
total_msg_len = len(prompt["chosen"])
|
||||||
|
total_msg_turns, remainder = divmod(total_msg_len, 2)
|
||||||
|
assert remainder == 0, "invalid number of turns"
|
||||||
|
|
||||||
|
messages: List[Message] = []
|
||||||
|
if system := prompt.get("system", None):
|
||||||
|
messages.append(Message(role="system", content=system, label=False))
|
||||||
|
for i in range(total_msg_turns):
|
||||||
|
if "prompt" in prompt:
|
||||||
|
messages.append(
|
||||||
|
Message(role="user", content=prompt["prompt"], label=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
messages.append(
|
||||||
|
Message(
|
||||||
|
role="user",
|
||||||
|
content=prompt["chosen"][i * 2]["content"],
|
||||||
|
label=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if i < total_msg_turns - 1:
|
||||||
|
messages.append(
|
||||||
|
Message(
|
||||||
|
role="assistant",
|
||||||
|
content=prompt["chosen"][i * 2 + 1]["content"],
|
||||||
|
label=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return MessageList(messages=messages)
|
||||||
|
|
||||||
|
def get_chosen(self, prompt) -> MessageList:
|
||||||
|
res = self.get_prompt(prompt)
|
||||||
|
res.messages.append(
|
||||||
|
Message(
|
||||||
|
role="assistant", content=prompt["chosen"][-1]["content"], label=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_rejected(self, prompt) -> MessageList:
|
||||||
|
res = self.get_prompt(prompt)
|
||||||
|
res.messages.append(
|
||||||
|
Message(
|
||||||
|
role="assistant", content=prompt["rejected"][-1]["content"], label=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
|
|||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
), True
|
), True
|
||||||
|
|
||||||
|
|
||||||
|
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
|
dataset_parser = ORPODatasetParsingStrategy()
|
||||||
|
|
||||||
|
chat_template_str = chat_templates(cfg.chat_template)
|
||||||
|
|
||||||
|
def transform_fn(sample, tokenizer=None):
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
res["prompt"] = tokenizer.apply_chat_template(
|
||||||
|
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
prompt_str_len = len(res["prompt"])
|
||||||
|
res["chosen"] = tokenizer.apply_chat_template(
|
||||||
|
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||||
|
add_generation_prompt=False,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)[prompt_str_len:]
|
||||||
|
res["rejected"] = tokenizer.apply_chat_template(
|
||||||
|
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||||
|
add_generation_prompt=False,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=False,
|
||||||
|
)[prompt_str_len:]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return transform_fn
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
Data processing modules
|
Data processing modules
|
||||||
"""
|
"""
|
||||||
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
|
|
||||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||||
encode_pretraining,
|
encode_pretraining,
|
||||||
wrap_pretraining_dataset,
|
wrap_pretraining_dataset,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
|
||||||
from axolotl.utils.data.sft import ( # noqa: F401
|
from axolotl.utils.data.sft import ( # noqa: F401
|
||||||
get_dataset_wrapper,
|
get_dataset_wrapper,
|
||||||
load_prepare_datasets,
|
load_prepare_datasets,
|
||||||
|
|||||||
@@ -1,17 +1,20 @@
|
|||||||
"""data handling specific to DPO"""
|
"""data handling specific to DPO"""
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from datasets import concatenate_datasets, load_dataset, load_from_disk
|
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
|
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
)
|
)
|
||||||
split_datasets.insert(i, ds)
|
split_datasets.insert(i, ds)
|
||||||
|
|
||||||
|
tokenizer = None
|
||||||
for i, data_set in enumerate(split_datasets):
|
for i, data_set in enumerate(split_datasets):
|
||||||
_type = dataset_cfgs[i]["type"]
|
_type = dataset_cfgs[i]["type"]
|
||||||
if _type:
|
if _type:
|
||||||
if isinstance(_type, DictDefault):
|
if isinstance(_type, DictDefault):
|
||||||
_type = "user_defined.default"
|
_type = "user_defined.default"
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
if _cfg.rl == "orpo":
|
||||||
split_datasets[i] = data_set.map(
|
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||||
|
else:
|
||||||
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
|
sig = inspect.signature(ds_transform_fn)
|
||||||
|
if "tokenizer" in sig.parameters:
|
||||||
|
if not tokenizer:
|
||||||
|
tokenizer = load_tokenizer(_cfg)
|
||||||
|
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
)
|
)
|
||||||
|
if isinstance(data_set, DatasetDict):
|
||||||
|
data_set = data_set["train"]
|
||||||
|
split_datasets[i] = data_set
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
# "prompt", "chosen" and "rejected" already preprocessed
|
||||||
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
|
|||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
||||||
trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
|
|||||||
return load_model(cfg, tokenizer)
|
return load_model(cfg, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
class TestHFDPOTrainerBuilder:
|
class TestHFRLTrainerBuilder:
|
||||||
"""
|
"""
|
||||||
TestCase class for DPO trainer builder
|
TestCase class for DPO trainer builder
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||||
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||||
training_arguments = builder.build_training_arguments(100)
|
training_arguments = builder.build_training_arguments(100)
|
||||||
assert training_arguments.adam_beta1 == 0.998
|
assert training_arguments.adam_beta1 == 0.998
|
||||||
assert training_arguments.adam_beta2 == 0.9
|
assert training_arguments.adam_beta2 == 0.9
|
||||||
|
|||||||
Reference in New Issue
Block a user