DPO transformers v0.29 fixes (#3560) [skip ci]

* Deperecate dpo_norm_loss

* Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179

* Remove deprecated rpo_alpha

* Remove dead_code tokenize_row

* Add _tokenize override to prevent double bos token on Llama DPO

* Fix DPO loss type now list not string

* Linting fix

* PR fixes

* update _tokenize override for DPO for multimodal
This commit is contained in:
Andrew Wu
2026-04-01 00:04:53 +01:00
committed by GitHub
parent bb622b83de
commit a81feabbd9
13 changed files with 100 additions and 126 deletions

View File

@@ -127,9 +127,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb: if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name training_args_kwargs["run_name"] = self.cfg.wandb_name

View File

@@ -405,15 +405,13 @@ class AxolotlTrainer(
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {} concatenated_batch = {}
max_length = max( max_length = max(inputs["input_ids"].shape[1], inputs["rejected_ids"].shape[1])
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
)
# Concatenate positive and negative inputs # Concatenate positive and negative inputs
concatenated_batch["input_ids"] = pad_to_length( concatenated_batch["input_ids"] = pad_to_length(
inputs["input_ids"], max_length, pad_token inputs["input_ids"], max_length, pad_token
) )
concatenated_batch["rejected_input_ids"] = pad_to_length( concatenated_batch["rejected_ids"] = pad_to_length(
inputs["rejected_input_ids"], max_length, pad_token inputs["rejected_ids"], max_length, pad_token
) )
concatenated_batch["labels"] = pad_to_length( concatenated_batch["labels"] = pad_to_length(
inputs["labels"], max_length, label_pad_token inputs["labels"], max_length, label_pad_token
@@ -432,7 +430,7 @@ class AxolotlTrainer(
).to(device=device) ).to(device=device)
input_ids = torch.cat( input_ids = torch.cat(
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]], [concatenated_batch["input_ids"], concatenated_batch["rejected_ids"]],
dim=0, dim=0,
).to(device=device) ).to(device=device)
attention_mask = torch.cat( attention_mask = torch.cat(

View File

@@ -21,7 +21,7 @@ class DPOStrategy:
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {} training_args_kwargs = {}
if cfg.rl is RLType.IPO: if cfg.rl is RLType.IPO:
training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["loss_type"] = ["ipo"]
# Label smoothing is not compatible with IPO # Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing: if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
@@ -30,8 +30,6 @@ class DPOStrategy:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None: if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_liger_kernel is not None: if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs return training_args_kwargs

View File

@@ -2,8 +2,7 @@
Axolotl specific DPO args Axolotl specific DPO args
""" """
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Optional
from trl import DPOConfig from trl import DPOConfig
@@ -15,6 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
""" """
DPO config for DPO training DPO config for DPO training
""" """
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.mixins import (
@@ -18,6 +19,7 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
) )
from axolotl.utils.data.utils import remove_double_bos_token
class AxolotlDPOTrainer( class AxolotlDPOTrainer(
@@ -53,36 +55,31 @@ class AxolotlDPOTrainer(
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
@staticmethod def _tokenize(
def tokenize_row( self,
features, processing_class: PreTrainedTokenizerBase | ProcessorMixin,
processing_class, input: str | list,
max_prompt_length: int | None = None, **kwargs,
max_completion_length: int | None = None, ) -> dict[str, list]:
add_special_tokens: bool = True, """
is_chat: bool = False, Override TRL's tokenization in DPO trainer to fix double bos_token bug (eg. llama).
) -> Dict: """
res = DPOTrainer.tokenize_row( result = super()._tokenize(
features, processing_class=processing_class, input=input, **kwargs
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
) )
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None: # Handle multimodal models
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs tokenizer = (
if res["chosen_input_ids"][0] == processing_class.bos_token_id: getattr(processing_class, "tokenizer", None)
res["chosen_input_ids"] = res["chosen_input_ids"][1:] if isinstance(processing_class, ProcessorMixin)
if res["rejected_input_ids"][0] == processing_class.bos_token_id: else processing_class
res["rejected_input_ids"] = res["rejected_input_ids"][1:] )
return res bos_token_id = getattr(tokenizer, "bos_token_id", None) if tokenizer else None
if bos_token_id is not None:
result = remove_double_bos_token(result, bos_token_id)
return result
def training_step( def training_step(
self, self,
@@ -94,20 +91,3 @@ class AxolotlDPOTrainer(
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
def concatenated_forward(
self,
model: nn.Module,
batch: dict[str, Union[list, torch.LongTensor]],
is_ref_model: bool = False,
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: list[str] = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = ["ipo"]
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
] ]
return { return {
"chosen_input_ids": chosen_tokenized["input_ids"], "chosen_ids": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"], "attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0, "labels_chosen": 1.0,
"rejected_input_ids": rejected_tokenized["input_ids"], "rejected_ids": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"], "attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0, "labels_rejected": 0.0,
} }

View File

@@ -130,7 +130,7 @@ class ORPODatasetParsingStrategy:
class ORPOTokenizingStrategy(PromptTokenizingStrategy): class ORPOTokenizingStrategy(PromptTokenizingStrategy):
""" """
rejected_input_ids rejected_ids
input_ids input_ids
rejected_attention_mask rejected_attention_mask
attention_mask attention_mask
@@ -169,7 +169,7 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
prompt_len = len(input_ids) prompt_len = len(input_ids)
# remap the input_ids, attention_mask and labels # remap the input_ids, attention_mask and labels
rejected_input_ids = input_ids rejected_ids = input_ids
rejected_labels = labels rejected_labels = labels
# pass the chosen prompt/row to the Prompter to get the formatted prompt # pass the chosen prompt/row to the Prompter to get the formatted prompt
chosen_message_list: MessageList = ( chosen_message_list: MessageList = (
@@ -191,7 +191,7 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
return { return {
"rejected_input_ids": rejected_input_ids, "rejected_ids": rejected_ids,
"rejected_labels": rejected_labels, "rejected_labels": rejected_labels,
"rejected_attention_mask": [1] * len(rejected_labels), "rejected_attention_mask": [1] * len(rejected_labels),
"input_ids": input_ids, "input_ids": input_ids,

View File

@@ -349,3 +349,14 @@ def handle_long_seq_in_dataset(
) )
return dataset return dataset
def remove_double_bos_token(example: dict[str, list], bos_token_id: int):
"""Remove double bos tokens that may occur when retokenizing preprocessed data
for tokenizers and chat templates that have a bos_token - eg. DPO + Llama.
"""
input_ids = example["input_ids"]
if len(input_ids) >= 2 and input_ids[0] == input_ids[1] == bos_token_id:
for key in example:
example[key] = example[key][1:]
return example

View File

@@ -294,7 +294,6 @@ class AxolotlInputConfig(
}, },
) )
dpo_label_smoothing: float | None = None dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_use_liger_kernel: bool | None = Field( dpo_use_liger_kernel: bool | None = Field(
default=None, default=None,
@@ -1111,12 +1110,6 @@ class AxolotlInputConfig(
"description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping." "description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping."
}, },
) )
rpo_alpha: float | None = Field(
default=None,
json_schema_extra={
"description": "Weighting of NLL term in loss from RPO paper"
},
)
simpo_gamma: float | None = Field( simpo_gamma: float | None = Field(
default=None, default=None,
json_schema_extra={"description": "Target reward margin for the SimPO loss"}, json_schema_extra={"description": "Target reward margin for the SimPO loss"},

View File

@@ -21,6 +21,8 @@ class DeprecatedParameters(BaseModel):
eval_max_new_tokens: int | None = None eval_max_new_tokens: int | None = None
dpo_use_logits_to_keep: bool | None = None dpo_use_logits_to_keep: bool | None = None
dpo_generate_during_eval: bool | None = None dpo_generate_during_eval: bool | None = None
dpo_norm_loss: bool | None = None
rpo_alpha: float | None = None
@field_validator("max_packed_sequence_len") @field_validator("max_packed_sequence_len")
@classmethod @classmethod
@@ -100,6 +102,26 @@ class DeprecatedParameters(BaseModel):
) )
return dpo_generate_during_eval return dpo_generate_during_eval
@field_validator("dpo_norm_loss")
@classmethod
def validate_dpo_norm_loss(cls, dpo_norm_loss):
if dpo_norm_loss is not None:
raise DeprecationWarning(
"`dpo_norm_loss` is no longer supported, "
"due to breaking changes in TRL >= 0.29.0"
)
return dpo_norm_loss
@field_validator("rpo_alpha")
@classmethod
def validate_rpo_alpha(cls, rpo_alpha):
if rpo_alpha is not None:
raise DeprecationWarning(
"`rpo_alpha` has been deprecated in TRL >= 0.29.0, "
"and now requires passing multiple loss types, which is not yet supported by Axolotl."
) # TODO: change this warning once multiple dpo loss types are supported.
return rpo_alpha
class RemappedParameters(BaseModel): class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names""" """Parameters that have been remapped to other names"""

View File

@@ -67,55 +67,6 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@with_temp_dir @with_temp_dir
def test_dpo_use_weighting(self, temp_dir): def test_dpo_use_weighting(self, temp_dir):
cfg = DictDefault( cfg = DictDefault(

View File

@@ -223,18 +223,18 @@ class OrpoTokenizationTest:
DictDefault({"chat_template": "chatml"}), DictDefault({"chat_template": "chatml"}),
) )
res = strat.tokenize_prompt(ds[0]) res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res assert "rejected_ids" in res
assert "rejected_labels" in res assert "rejected_labels" in res
assert "input_ids" in res assert "input_ids" in res
assert "labels" in res assert "labels" in res
assert "prompt_attention_mask" in res assert "prompt_attention_mask" in res
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"]) assert len(res["rejected_ids"]) == len(res["rejected_labels"])
assert len(res["input_ids"]) == len(res["labels"]) assert len(res["input_ids"]) == len(res["labels"])
assert len(res["input_ids"]) == len(res["prompt_attention_mask"]) assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
assert res["rejected_labels"][0] == -100 assert res["rejected_labels"][0] == -100
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1] assert res["rejected_ids"][-1] == res["rejected_labels"][-1]
assert res["labels"][0] == -100 assert res["labels"][0] == -100
assert res["input_ids"][-1] == res["labels"][-1] assert res["input_ids"][-1] == res["labels"][-1]

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
from datasets import Dataset from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -541,5 +541,33 @@ class TestHandleLongSeqInDataset(unittest.TestCase):
self.assertEqual(len(result[0]["input_ids"]), 3) self.assertEqual(len(result[0]["input_ids"]), 3)
class TestRemoveDoubleBOSToken(unittest.TestCase):
def test_no_remove_bos_token(self):
input_ids = [0, 1, 2]
labels = [1, 2, 3]
example = {
"input_ids": input_ids,
"labels": labels,
}
example = remove_double_bos_token(example, 0)
assert example["input_ids"] == input_ids
assert example["labels"] == labels
def test_remove_bos_token(self):
input_ids = [0, 0, 1]
labels = [0, 1, 2]
example = {
"input_ids": input_ids,
"labels": labels,
}
example = remove_double_bos_token(example, 0)
assert example["input_ids"] == [0, 1]
assert example["labels"] == [1, 2]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()