DPO transformers v0.29 fixes (#3560) [skip ci]

* Deperecate dpo_norm_loss

* Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179

* Remove deprecated rpo_alpha

* Remove dead_code tokenize_row

* Add _tokenize override to prevent double bos token on Llama DPO

* Fix DPO loss type now list not string

* Linting fix

* PR fixes

* update _tokenize override for DPO for multimodal
This commit is contained in:
Andrew Wu
2026-04-01 00:04:53 +01:00
committed by GitHub
parent bb622b83de
commit a81feabbd9
13 changed files with 100 additions and 126 deletions

View File

@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token
from axolotl.utils.dict import DictDefault
@@ -541,5 +541,33 @@ class TestHandleLongSeqInDataset(unittest.TestCase):
self.assertEqual(len(result[0]["input_ids"]), 3)
class TestRemoveDoubleBOSToken(unittest.TestCase):
def test_no_remove_bos_token(self):
input_ids = [0, 1, 2]
labels = [1, 2, 3]
example = {
"input_ids": input_ids,
"labels": labels,
}
example = remove_double_bos_token(example, 0)
assert example["input_ids"] == input_ids
assert example["labels"] == labels
def test_remove_bos_token(self):
input_ids = [0, 0, 1]
labels = [0, 1, 2]
example = {
"input_ids": input_ids,
"labels": labels,
}
example = remove_double_bos_token(example, 0)
assert example["input_ids"] == [0, 1]
assert example["labels"] == [1, 2]
if __name__ == "__main__":
unittest.main()