* orpo trainer

* rl handling for orpo

* support for remove_unused_columns

* orpo fixes

* fix loader for orpo

* chore: lint

* fix default for remove_unused_columns

* roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora

* better handling of system message for orpo

* revert system prompt changes for chat templtes

* no need for else condition

* split dataset parsing into it's own component
This commit is contained in:
Wing Lian
2024-03-18 13:10:00 -04:00
committed by GitHub
parent e8c8ea64b3
commit 2ea70ebbd8
14 changed files with 451 additions and 24 deletions

View File

@@ -8,7 +8,8 @@ from pathlib import Path
from typing import Optional
import pytest
from transformers import AutoTokenizer, LlamaTokenizer
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
Llama2ChatPrompter,
LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.orpo.chat_template import load
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl")
@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
)
class OrpoTokenizationTest(unittest.TestCase):
"""test case for the ORPO tokenization"""
def setUp(self) -> None:
# pylint: disable=duplicate-code
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken(
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
),
]
)
self.tokenizer = tokenizer
self.dataset = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
def test_orpo_integration(self):
strat = load(
self.tokenizer,
DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}),
)
res = strat.tokenize_prompt(self.dataset[0])
assert "rejected_input_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
assert "labels" in res
assert "prompt_attention_mask" in res
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
assert len(res["input_ids"]) == len(res["labels"])
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
assert res["rejected_labels"][0] == -100
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
assert res["labels"][0] == -100
assert res["input_ids"][-1] == res["labels"][-1]
assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0
if __name__ == "__main__":
unittest.main()