ORPO (#1419)
* orpo trainer * rl handling for orpo * support for remove_unused_columns * orpo fixes * fix loader for orpo * chore: lint * fix default for remove_unused_columns * roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora * better handling of system message for orpo * revert system prompt changes for chat templates * no need for else condition * split dataset parsing into its own component
This commit is contained in:
@@ -8,7 +8,8 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_strategies.orpo.chat_template import load
|
||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
)
|
||||
|
||||
|
||||
class OrpoTokenizationTest(unittest.TestCase):
    """Integration test for the ORPO chat-template tokenization strategy.

    A Mistral tokenizer is extended with the ChatML markers and one row of a
    preference dataset is tokenized; the test then checks the keys, lengths
    and boundary values of the resulting mapping.
    """

    def setUp(self) -> None:
        # pylint: disable=duplicate-code
        llama_tok = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

        # "<|im_end|>" is registered as the EOS token ...
        end_marker = AddedToken(
            "<|im_end|>", rstrip=False, lstrip=False, normalized=False
        )
        llama_tok.add_special_tokens({"eos_token": end_marker})

        # ... while "<|im_start|>" is added as an ordinary extra token.
        start_marker = AddedToken(
            "<|im_start|>", rstrip=False, lstrip=False, normalized=False
        )
        llama_tok.add_tokens([start_marker])

        self.tokenizer = llama_tok

        # Only the first row of the train split is kept for the test.
        train_split = load_dataset(
            "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
        )
        self.dataset = train_split.select([0])

    def test_orpo_integration(self):
        """Tokenize the first dataset row and validate the output fields."""
        train_opts = DictDefault({"train_on_inputs": False})
        template_opts = DictDefault({"chat_template": "chatml"})
        strategy = load(self.tokenizer, train_opts, template_opts)

        tokenized = strategy.tokenize_prompt(self.dataset[0])

        # Every expected field must be present in the tokenized result.
        expected_keys = (
            "rejected_input_ids",
            "rejected_labels",
            "input_ids",
            "labels",
            "prompt_attention_mask",
        )
        for key in expected_keys:
            assert key in tokenized

        # Paired sequences must line up element for element.
        rejected_len = len(tokenized["rejected_input_ids"])
        assert rejected_len == len(tokenized["rejected_labels"])
        chosen_len = len(tokenized["input_ids"])
        assert chosen_len == len(tokenized["labels"])
        assert chosen_len == len(tokenized["prompt_attention_mask"])

        # Rejected sequence: first label is -100, last label equals the last id.
        assert tokenized["rejected_labels"][0] == -100
        assert tokenized["rejected_input_ids"][-1] == tokenized["rejected_labels"][-1]

        # Chosen sequence: same boundary expectations.
        assert tokenized["labels"][0] == -100
        assert tokenized["input_ids"][-1] == tokenized["labels"][-1]

        # The prompt mask starts at 1 and ends at 0.
        prompt_mask = tokenized["prompt_attention_mask"]
        assert prompt_mask[0] == 1
        assert prompt_mask[-1] == 0
|
||||
|
||||
|
||||
# When this module is executed directly, hand control to unittest's runner.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user