Add a chat_template prompt strategy for DPO (#1725)
* Implementing a basic chat_template strategy for DPO datasets. This mimics the sft chat_template strategy such that users can: * Specify the messages field * Specify the per-message role and content fields * Specify the chosen and rejected fields * Let the tokenizer construct the raw prompt * Ensure the chosen and rejected fields don't have any prefix tokens * Adding additional dpo chat template unittests * Rename test class
This commit is contained in:
81
examples/llama-3/instruct-dpo-lora-8b.yml
Normal file
81
examples/llama-3/instruct-dpo-lora-8b.yml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Example config: DPO LoRA run on Llama-3 8B Instruct using the
# chat_template.default DPO prompt strategy.

# Base model and tokenizer
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

# Weight loading / quantization
load_in_8bit: true
load_in_4bit: false
strict: false

# Chat template and preference-tuning method
chat_template: llama3
rl: dpo

# Dataset: field names below tell the strategy where the conversation,
# the chosen reply and the rejected reply live in each row.
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    chat_template: llama3
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    # Target role -> list of role names used in the dataset
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out

# Sequence handling
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

# LoRA adapter
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# Weights & Biases (left unset)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# Optimization
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

# Precision and batching behaviour
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

# Runtime / attention options
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

# Warmup, evaluation and checkpointing cadence
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
|
||||||
78
src/axolotl/prompt_strategies/dpo/chat_template.py
Normal file
78
src/axolotl/prompt_strategies/dpo/chat_template.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
DPO prompt strategies for using tokenizer chat templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
|
|
||||||
|
|
||||||
|
def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a DPO sample transform that renders text with a tokenizer chat template.

    Args:
        cfg: Run configuration. ``cfg["datasets"][dataset_idx]`` supplies the
            per-dataset field names and role mapping; ``cfg.chat_template`` is
            passed to ``chat_templates`` to obtain the template handed to the
            tokenizer. The dataset-level ``chat_template`` key is not read here.
        dataset_idx: Index of the dataset entry whose options are used.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A function ``transform_fn(sample, tokenizer=None)`` returning a dict
        with the string entries ``prompt``, ``chosen`` and ``rejected``.
    """
    ds_cfg = cfg["datasets"][dataset_idx]
    chat_template_str = chat_templates(cfg.chat_template)

    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    field_message_role = ds_cfg.get("message_field_role", "role")
    field_message_content = ds_cfg.get("message_field_content", "content")
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    # Config maps target role -> dataset role names; invert it so a dataset
    # role name can be looked up directly.
    role_map = {
        source: target
        for target, sources in role_map_inv.items()
        for source in sources
    }

    def _to_message(raw):
        """Convert a dataset message into the role/content dict the template expects."""
        return {
            "role": role_map[raw[field_message_role]],
            "content": raw[field_message_content],
        }

    def _render_completion(tokenizer, message, label):
        """Render a single reply and drop everything the template puts before its content."""
        rendered = tokenizer.apply_chat_template(
            [message],
            add_generation_prompt=False,
            chat_template=chat_template_str,
            tokenize=False,
        )
        content = message["content"]
        start = rendered.find(content)
        if start == -1:
            # The template may have normalized surrounding whitespace, in
            # which case the raw content no longer appears verbatim.
            start = rendered.find(content.strip())
        if start == -1:
            # Slicing with -1 would silently keep only the last character,
            # so fail loudly instead of emitting a corrupted completion.
            raise ValueError(
                f"chat template output does not contain the {label} message "
                "content; cannot strip the template prefix"
            )
        return rendered[start:]

    def transform_fn(sample, tokenizer=None):
        """
        Turn one dataset row into ``prompt`` / ``chosen`` / ``rejected`` strings.

        ``tokenizer`` defaults to ``None`` in the signature but is required in
        practice: its ``apply_chat_template`` method is called unconditionally.

        Raises:
            KeyError: if a message role is missing from the configured role map.
            ValueError: if a rendered reply does not contain its own content.
        """
        messages = [_to_message(m) for m in sample[field_messages]]
        chosen = _to_message(sample[field_chosen])
        rejected = _to_message(sample[field_rejected])

        result = {}
        # The prompt ends with the generation prompt so that the completions
        # below can be appended to it directly.
        result["prompt"] = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            chat_template=chat_template_str,
            tokenize=False,
        )
        result["chosen"] = _render_completion(tokenizer, chosen, "chosen")
        result["rejected"] = _render_completion(tokenizer, rejected, "rejected")

        return result

    return transform_fn
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""data handling specific to DPO"""
|
"""data handling specific to DPO"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
|||||||
"""Helper function to process and color tokens."""
|
"""Helper function to process and color tokens."""
|
||||||
colored_tokens = [
|
colored_tokens = [
|
||||||
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
|
||||||
for token in tokenizer.encode(tokens)
|
for token in tokenizer.encode(tokens, add_special_tokens=False)
|
||||||
]
|
]
|
||||||
return colored_tokens
|
return colored_tokens
|
||||||
|
|
||||||
|
|||||||
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""
|
||||||
|
tests for the DPO chat_template prompt strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.dpo.chat_template import default
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    """Single-row DPO dataset that uses the default field and role names."""
    # pylint: disable=duplicate-code
    conversation = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "goodbye"},
    ]
    row = {
        "messages": conversation,
        "chosen": {"role": "assistant", "content": "goodbye"},
        "rejected": {"role": "assistant", "content": "party on"},
    }
    return Dataset.from_list([row])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
    """Single-row DPO dataset that uses non-default field and role names."""
    # pylint: disable=duplicate-code
    conversation = [
        {"speaker": "human", "text": "hello"},
        {"speaker": "agent", "text": "hello"},
        {"speaker": "human", "text": "goodbye"},
    ]
    row = {
        "conversation": conversation,
        "better": {"speaker": "agent", "text": "goodbye"},
        "worse": {"speaker": "agent", "text": "party on"},
    }
    return Dataset.from_list([row])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    """Load the Llama-3 tokenizer and set ``<|eot_id|>`` as its EOS token."""
    llama3_tok = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    llama3_tok.eos_token = "<|eot_id|>"
    return llama3_tok
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssistantDPOChatTemplateLlama3:
    """
    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
    """

    # Prompt both tests expect: the three-turn conversation followed by the
    # assistant generation header.
    _expected_prompt = "".join(
        [
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>",
            "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>",
            "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>",
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
        ]
    )

    def _assert_llama3_output(self, result):
        """Check the prompt and the prefix-free chosen/rejected completions."""
        assert result["prompt"] == self._expected_prompt
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
        # pylint: disable=duplicate-code
        dataset_options = {"chat_template": "llama3"}
        strategy_cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [dataset_options],
            }
        )
        transform_fn = default(strategy_cfg)
        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
        self._assert_llama3_output(result)

    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
        # pylint: disable=duplicate-code
        dataset_options = {
            "chat_template": "llama3",
            "field_messages": "conversation",
            "field_chosen": "better",
            "field_rejected": "worse",
            "message_field_role": "speaker",
            "message_field_content": "text",
            "roles": {
                "user": ["human"],
                "assistant": ["agent"],
                "system": ["sys"],
            },
        }
        strategy_cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [dataset_options],
            }
        )
        transform_fn = default(strategy_cfg)
        result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
        self._assert_llama3_output(result)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this module directly as a script.
# NOTE(review): the tests in this file take pytest fixtures and the test class
# does not subclass unittest.TestCase, so unittest.main() presumably collects
# nothing here — confirm, and consider invoking the file through pytest.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user