Fix(message): Improve error message for bad format (#365)
This commit is contained in:
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Generator, List, Sequence
|
from typing import Generator, List, Sequence
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -190,7 +190,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = [] # pylint: disable=R0801
|
conv.messages = [] # pylint: disable=R0801
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2]
|
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||||
if sentence["value"]:
|
if sentence["value"]:
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
yield conv
|
yield conv
|
||||||
|
|||||||
@@ -260,6 +260,11 @@ class Conversation:
|
|||||||
self.messages.append([role, message])
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
|
||||||
|
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||||
|
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
@@ -316,7 +321,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
conv.messages = []
|
conv.messages = []
|
||||||
for j, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
assert role == conv.roles[j % 2]
|
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
for part in conv.get_prompt():
|
for part in conv.get_prompt():
|
||||||
|
|||||||
Reference in New Issue
Block a user