better handling and logging of empty sharegpt turns (#603)
This commit is contained in:
@@ -3,7 +3,9 @@ import json
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
@@ -29,6 +31,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
@@ -64,6 +72,24 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
||||
|
||||
def test_sharegpt_warnings_integration(self):
|
||||
with open(
|
||||
Path(__file__).parent / "fixtures/conversation.missingturns.json",
|
||||
encoding="utf-8",
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
prompter = ShareGPTPrompter("chat")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
Reference in New Issue
Block a user