Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields: * The data field holding message arrays * the role and content fields for each message * role mapping from source to target types additionally this adds a sample llama3-8b instruct template using the chat template
This commit is contained in:
76
examples/llama-3/instruct-lora-8b.yml
Normal file
76
examples/llama-3/instruct-lora-8b.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
chat_template: llama3
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -1,24 +1,55 @@
|
||||
"""
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class ChatTemplatePrompter(Prompter):
|
||||
"""prompter for HF chat templates"""
|
||||
|
||||
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
max_length=2048,
|
||||
message_field_role: str = "from",
|
||||
message_field_content: str = "value",
|
||||
roles: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
else:
|
||||
self.roles = {
|
||||
"human": "user",
|
||||
"user": "user",
|
||||
"assistant": "assistant",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
self.message_field_role = message_field_role
|
||||
self.message_field_content = message_field_content
|
||||
self.tokenizer = tokenizer
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||
turns = [
|
||||
{
|
||||
"role": self.roles[t[self.message_field_role]],
|
||||
"content": t[self.message_field_content],
|
||||
}
|
||||
for t in conversation
|
||||
]
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
turns,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
_messages = "conversations"
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
||||
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {
|
||||
"human": "user",
|
||||
"user": "user",
|
||||
"assistant": "assistant",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
turns = [
|
||||
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
return prompt[self.messages]
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
chat_template = (
|
||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||
)
|
||||
message_field_role = (
|
||||
ds_cfg["message_field_role"]
|
||||
if ds_cfg and "message_field_role" in ds_cfg
|
||||
else "from"
|
||||
)
|
||||
message_field_content = (
|
||||
ds_cfg["message_field_content"]
|
||||
if ds_cfg and "message_field_content" in ds_cfg
|
||||
else "value"
|
||||
)
|
||||
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_templates(chat_template),
|
||||
message_field_role=message_field_role,
|
||||
message_field_content=message_field_content,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
return strategy
|
||||
|
||||
@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
field_messages: Optional[str] = None
|
||||
message_field_role: Optional[str] = None
|
||||
message_field_content: Optional[str] = None
|
||||
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplatePrompter,
|
||||
ChatTemplateStrategy,
|
||||
load,
|
||||
)
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "hello",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "goodbye",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "goodbye",
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestAssistantChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = load(
|
||||
llama3_tokenizer,
|
||||
DictDefault(
|
||||
{
|
||||
"train_on_inputs": False,
|
||||
"sequence_len": 512,
|
||||
}
|
||||
),
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
"field_messages": "messages",
|
||||
}
|
||||
),
|
||||
)
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_templates("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
},
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False,
|
||||
512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class TestSharegptChatTemplateLlama3:
|
||||
"""
|
||||
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
||||
|
||||
Reference in New Issue
Block a user