DPO transformers v0.29 fixes (#3560) [skip ci]
* Deperecate dpo_norm_loss * Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179 * Remove deprecated rpo_alpha * Remove dead_code tokenize_row * Add _tokenize override to prevent double bos token on Llama DPO * Fix DPO loss type now list not string * Linting fix * PR fixes * update _tokenize override for DPO for multimodal
This commit is contained in:
@@ -223,18 +223,18 @@ class OrpoTokenizationTest:
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(ds[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
assert "labels" in res
|
||||
assert "prompt_attention_mask" in res
|
||||
|
||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["rejected_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["input_ids"]) == len(res["labels"])
|
||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
||||
|
||||
assert res["rejected_labels"][0] == -100
|
||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
||||
assert res["rejected_ids"][-1] == res["rejected_labels"][-1]
|
||||
|
||||
assert res["labels"][0] == -100
|
||||
assert res["input_ids"][-1] == res["labels"][-1]
|
||||
|
||||
Reference in New Issue
Block a user