Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
This commit is contained in:
@@ -6,8 +6,12 @@ from pathlib import Path
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompter
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
|
||||
|
||||
logging.basicConfig(level="INFO")
|
||||
|
||||
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_sharegpt_integration(self):
|
||||
print(Path(__file__).parent)
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||
) as fin:
|
||||
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
prompter = NoSystemPrompter()
|
||||
# pylint: disable=duplicate-code
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {
|
||||
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
||||
"output": "world!",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(3186)
|
||||
assert example["labels"][world_idx] == 3186
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
def test_alpaca(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
world_idx = example["input_ids"].index(6324)
|
||||
assert example["labels"][world_idx] == 6324
|
||||
assert example["labels"][world_idx - 1] == -100
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user