better handling and logging of empty sharegpt turns (#603)

This commit is contained in:
Wing Lian
2023-09-22 16:13:42 -04:00
committed by GitHub
parent 501958bb6f
commit a363604dcf
3 changed files with 105 additions and 14 deletions

View File

@@ -3,7 +3,9 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from transformers import AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
@@ -29,6 +31,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
# pytest's log-capture fixture; None until inject_fixtures assigns it, which
# lets unittest-style test methods reach caplog through self.
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
    """
    autouse fixture that stores pytest's caplog fixture on the test instance
    so the test methods can inspect captured log records via self._caplog
    """
    # the parameter must stay named `caplog`: pytest resolves fixtures by name
    self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -64,6 +72,24 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
self.assertEqual(example[fields], tokenized_conversation[fields])
def test_sharegpt_warnings_integration(self):
    """
    tests that tokenizing the `conversation.missingturns.json` fixture with the
    sharegpt strategy logs a warning about an empty assistant turn
    """
    fixture_path = (
        Path(__file__).parent / "fixtures/conversation.missingturns.json"
    )
    with open(fixture_path, encoding="utf-8") as fin:
        conversation = json.load(fin)

    prompter = ShareGPTPrompter("chat")
    strat = ShareGPTPromptTokenizingStrategy(
        prompter,
        self.tokenizer,
        False,
        2048,
    )
    with self._caplog.at_level(logging.WARNING):
        strat.tokenize_prompt(conversation)

    # Search every captured record rather than indexing a fixed position:
    # a positional lookup raises IndexError when fewer records are emitted
    # and breaks whenever an unrelated warning shifts the ordering.
    messages = [record.getMessage() for record in self._caplog.records]
    assert any(
        "assistant turn has empty text" in message for message in messages
    ), f"expected empty assistant turn warning, got: {messages}"
def test_no_sys_prompt(self):
"""
tests the interface between the user and assistant parts