Fix(message): Improve error message for bad format (#365)

This commit is contained in:
NanoCode012
2023-08-13 01:16:18 +09:00
committed by GitHub
parent b5212068ac
commit e37d9358e6
2 changed files with 8 additions and 3 deletions

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
@dataclass @dataclass
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801 conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2] assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if sentence["value"]: if sentence["value"]:
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
yield conv yield conv

View File

@@ -260,6 +260,11 @@ class Conversation:
self.messages.append([role, message]) self.messages.append([role, message])
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter: # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
@@ -316,7 +321,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
assert role == conv.roles[j % 2] assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
for part in conv.get_prompt(): for part in conv.get_prompt():