Generalizing the chat_template prompt strategy (#1660) [skip ci]

The strategy now supports configuring several fields: * The data field holding message arrays * the role and
content fields for each message * role mapping from source to target types

additionally this adds a sample llama3-8b instruct template using the chat template
This commit is contained in:
Keith Stevens
2024-05-29 00:24:13 +09:00
committed by GitHub
parent 5f91064040
commit cc11c6bce2
4 changed files with 258 additions and 17 deletions

View File

@@ -1,24 +1,55 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
import logging
from typing import Any, Dict, List, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
LOG = logging.getLogger("axolotl")
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
def __init__(
self,
tokenizer,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
roles: Optional[Dict[str, List[str]]] = None,
):
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
else:
self.roles = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
}
for t in conversation
]
return self.tokenizer.apply_chat_template(
conversation,
turns,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
_messages = "conversations"
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
return prompt[self.messages]
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
message_field_role = (
ds_cfg["message_field_role"]
if ds_cfg and "message_field_role" in ds_cfg
else "from"
)
message_field_content = (
ds_cfg["message_field_content"]
if ds_cfg and "message_field_content" in ds_cfg
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
ChatTemplatePrompter(
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
message_field_role: Optional[str] = None
message_field_content: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None