DPO transformers v0.29 fixes (#3560) [skip ci]
* Deprecate dpo_norm_loss * Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179 * Remove deprecated rpo_alpha * Remove dead code tokenize_row * Add _tokenize override to prevent double BOS token on Llama DPO * Fix DPO loss type, which is now a list rather than a string * Linting fix * PR fixes * Update _tokenize override for DPO for multimodal
This commit is contained in:
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -541,5 +541,33 @@ class TestHandleLongSeqInDataset(unittest.TestCase):
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
|
||||
|
||||
class TestRemoveDoubleBOSToken(unittest.TestCase):
    """Tests for ``remove_double_bos_token``.

    Each test calls the helper with an example dict holding ``input_ids``
    and ``labels`` plus the token id ``0`` (presumably the BOS token id --
    confirm against the helper's signature) and checks the returned dict.
    """

    def test_no_remove_bos_token(self):
        """An example whose first two ``input_ids`` differ comes back unchanged."""
        example = {
            "input_ids": [0, 1, 2],
            "labels": [1, 2, 3],
        }

        result = remove_double_bos_token(example, 0)

        # Compare against fresh literals rather than the list objects that
        # were passed in: if the helper mutated its input in place, comparing
        # a list with itself would pass no matter what it did.
        self.assertEqual(result["input_ids"], [0, 1, 2])
        self.assertEqual(result["labels"], [1, 2, 3])

    def test_remove_bos_token(self):
        """A doubled leading token drops the first entry of both sequences."""
        example = {
            "input_ids": [0, 0, 1],
            "labels": [0, 1, 2],
        }

        result = remove_double_bos_token(example, 0)

        self.assertEqual(result["input_ids"], [0, 1])
        self.assertEqual(result["labels"], [1, 2])
|
||||
|
||||
|
||||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user