Generalizing the chat_template prompt strategy (#1660) [skip ci]
The strategy now supports configuring several fields: * the data field holding the message arrays * the role and content fields of each message * the role mapping from source to target types. Additionally, this adds a sample Llama-3 8B Instruct LoRA config (examples/llama-3/instruct-lora-8b.yml) that uses the chat_template strategy.
This commit is contained in:
76
examples/llama-3/instruct-lora-8b.yml
Normal file
76
examples/llama-3/instruct-lora-8b.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: llama3
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
chat_template: llama3
|
||||||
|
field_messages: messages
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -1,24 +1,55 @@
|
|||||||
"""
|
"""
|
||||||
HF Chat Templates prompt strategy
|
HF Chat Templates prompt strategy
|
||||||
"""
|
"""
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import Prompter
|
from axolotl.prompters import Prompter
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplatePrompter(Prompter):
|
class ChatTemplatePrompter(Prompter):
|
||||||
"""prompter for HF chat templates"""
|
"""prompter for HF chat templates"""
|
||||||
|
|
||||||
def __init__(self, tokenizer, chat_template=None, max_length=2048):
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
chat_template=None,
|
||||||
|
max_length=2048,
|
||||||
|
message_field_role: str = "from",
|
||||||
|
message_field_content: str = "value",
|
||||||
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
|
):
|
||||||
|
if roles:
|
||||||
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
|
else:
|
||||||
|
self.roles = {
|
||||||
|
"human": "user",
|
||||||
|
"user": "user",
|
||||||
|
"assistant": "assistant",
|
||||||
|
"gpt": "assistant",
|
||||||
|
}
|
||||||
|
self.message_field_role = message_field_role
|
||||||
|
self.message_field_content = message_field_content
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False):
|
def build_prompt(self, conversation, add_generation_prompt=False):
|
||||||
|
turns = [
|
||||||
|
{
|
||||||
|
"role": self.roles[t[self.message_field_role]],
|
||||||
|
"content": t[self.message_field_content],
|
||||||
|
}
|
||||||
|
for t in conversation
|
||||||
|
]
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
conversation,
|
turns,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for instruction-based prompts.
|
Tokenizing strategy for instruction-based prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_messages = "conversations"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self):
|
||||||
|
return self._messages
|
||||||
|
|
||||||
|
@messages.setter
|
||||||
|
def messages(self, messages):
|
||||||
|
self._messages = messages
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
|
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
|
|
||||||
if not self.train_on_inputs:
|
if not self.train_on_inputs:
|
||||||
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
conversations = prompt["conversations"]
|
return prompt[self.messages]
|
||||||
# remap roles - allow for assistant turn
|
|
||||||
role_map = {
|
|
||||||
"human": "user",
|
|
||||||
"user": "user",
|
|
||||||
"assistant": "assistant",
|
|
||||||
"gpt": "assistant",
|
|
||||||
}
|
|
||||||
turns = [
|
|
||||||
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
|
|
||||||
]
|
|
||||||
return turns
|
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
chat_template = (
|
chat_template = (
|
||||||
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
|
||||||
)
|
)
|
||||||
|
message_field_role = (
|
||||||
|
ds_cfg["message_field_role"]
|
||||||
|
if ds_cfg and "message_field_role" in ds_cfg
|
||||||
|
else "from"
|
||||||
|
)
|
||||||
|
message_field_content = (
|
||||||
|
ds_cfg["message_field_content"]
|
||||||
|
if ds_cfg and "message_field_content" in ds_cfg
|
||||||
|
else "value"
|
||||||
|
)
|
||||||
|
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
|
||||||
|
|
||||||
strategy = ChatTemplateStrategy(
|
strategy = ChatTemplateStrategy(
|
||||||
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
|
ChatTemplatePrompter(
|
||||||
|
tokenizer,
|
||||||
|
chat_templates(chat_template),
|
||||||
|
message_field_role=message_field_role,
|
||||||
|
message_field_content=message_field_content,
|
||||||
|
roles=roles,
|
||||||
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||||
|
strategy.messages = ds_cfg["field_messages"]
|
||||||
return strategy
|
return strategy
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
field_messages: Optional[str] = None
|
field_messages: Optional[str] = None
|
||||||
|
message_field_role: Optional[str] = None
|
||||||
|
message_field_content: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
roles: Optional[Dict[str, List[str]]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
tests for chat_template prompt strategy
|
tests for chat_template prompt strategy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -10,8 +11,39 @@ from transformers import AutoTokenizer
|
|||||||
from axolotl.prompt_strategies.chat_template import (
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
ChatTemplatePrompter,
|
ChatTemplatePrompter,
|
||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
|
load,
|
||||||
)
|
)
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="assistant_dataset")
|
||||||
|
def fixture_assistant_dataset():
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
return Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "goodbye",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "goodbye",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sharegpt_dataset")
|
@pytest.fixture(name="sharegpt_dataset")
|
||||||
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistantChatTemplateLlama3:
|
||||||
|
"""
|
||||||
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
strategy = load(
|
||||||
|
llama3_tokenizer,
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"train_on_inputs": False,
|
||||||
|
"sequence_len": 512,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"chat_template": "llama3",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
"system": ["system"],
|
||||||
|
},
|
||||||
|
"field_messages": "messages",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
128000, # bos
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer,
|
||||||
|
chat_templates("llama3"),
|
||||||
|
message_field_role="role",
|
||||||
|
message_field_content="content",
|
||||||
|
roles={
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
"system": ["system"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
llama3_tokenizer,
|
||||||
|
False,
|
||||||
|
512,
|
||||||
|
)
|
||||||
|
strategy.messages = "messages"
|
||||||
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
128000, # bos
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class TestSharegptChatTemplateLlama3:
|
class TestSharegptChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
|
|||||||
Reference in New Issue
Block a user