From cc11c6bce2c820e40d035c4028d0e320b13eeb76 Mon Sep 17 00:00:00 2001
From: Keith Stevens
Date: Wed, 29 May 2024 00:24:13 +0900
Subject: [PATCH] Generalizing the chat_template prompt strategy (#1660)
 [skip ci]

The strategy now supports configuring several fields:

* the dataset field holding the message arrays
* the role and content fields within each message
* a role mapping from source role names to target role names

Additionally, this adds a sample llama3-8b instruct config that uses the
chat_template strategy.
---
 examples/llama-3/instruct-lora-8b.yml         |  76 ++++++++++++
 .../prompt_strategies/chat_template.py        |  84 ++++++++++---
 .../config/models/input/v0_4_1/__init__.py    |   2 +
 .../prompt_strategies/test_chat_templates.py  | 113 ++++++++++++++++++
 4 files changed, 258 insertions(+), 17 deletions(-)
 create mode 100644 examples/llama-3/instruct-lora-8b.yml

diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
new file mode 100644
index 000000000..754c9ad5c
--- /dev/null
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -0,0 +1,76 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+    chat_template: llama3
+    field_messages: messages
+    message_field_role: role
+    message_field_content: content
+    roles:
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
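A note on the roles: block in the example above: the config maps each target
role to the list of source role names that should be rewritten to it, and
ChatTemplatePrompter (below) inverts that into a flat source-to-target lookup.
A minimal, self-contained sketch of the inversion, using the strategy's
default mappings as illustrative values:

    # Config shape: target role -> list of source role names.
    roles_cfg = {"user": ["human", "user"], "assistant": ["gpt", "assistant"]}

    # The prompter's inversion: source role -> target role.
    roles = {s: t for t, sources in roles_cfg.items() for s in sources}

    assert roles == {
        "human": "user",
        "user": "user",
        "gpt": "assistant",
        "assistant": "assistant",
    }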
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 8dff3845b..b052469dc 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -1,24 +1,55 @@
 """
 HF Chat Templates prompt strategy
 """
-from typing import Any, Dict, Optional
+
+import logging
+from typing import Any, Dict, List, Optional
 
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import Prompter
 from axolotl.utils.chat_templates import chat_templates
 
+LOG = logging.getLogger("axolotl")
+
 
 class ChatTemplatePrompter(Prompter):
     """prompter for HF chat templates"""
 
-    def __init__(self, tokenizer, chat_template=None, max_length=2048):
+    def __init__(
+        self,
+        tokenizer,
+        chat_template=None,
+        max_length=2048,
+        message_field_role: str = "from",
+        message_field_content: str = "value",
+        roles: Optional[Dict[str, List[str]]] = None,
+    ):
+        if roles:
+            self.roles = {s: t for t, sources in roles.items() for s in sources}
+        else:
+            self.roles = {
+                "human": "user",
+                "user": "user",
+                "assistant": "assistant",
+                "gpt": "assistant",
+            }
+
+        self.message_field_role = message_field_role
+        self.message_field_content = message_field_content
         self.tokenizer = tokenizer
         self.chat_template = chat_template
         self.max_length = max_length
 
     def build_prompt(self, conversation, add_generation_prompt=False):
+        turns = [
+            {
+                "role": self.roles[t[self.message_field_role]],
+                "content": t[self.message_field_content],
+            }
+            for t in conversation
+        ]
+
         return self.tokenizer.apply_chat_template(
-            conversation,
+            turns,
             truncation=True,
             max_length=self.max_length,
             add_generation_prompt=add_generation_prompt,
@@ -31,9 +62,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
+    _messages = "conversations"
+
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def tokenize_prompt(self, prompt):
         turns = self.get_conversation_thread(prompt)
-        prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
+        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
         input_ids = self.prompter.build_prompt(turns)
 
         if not self.train_on_inputs:
@@ -51,28 +92,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return tokenized_prompt
 
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap roles - allow for assistant turn
-        role_map = {
-            "human": "user",
-            "user": "user",
-            "assistant": "assistant",
-            "gpt": "assistant",
-        }
-        turns = [
-            {"role": role_map[t["from"]], "content": t["value"]} for t in conversations
-        ]
-        return turns
+        return prompt[self.messages]
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     chat_template = (
         ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
     )
+    message_field_role = (
+        ds_cfg["message_field_role"]
+        if ds_cfg and "message_field_role" in ds_cfg
+        else "from"
+    )
+    message_field_content = (
+        ds_cfg["message_field_content"]
+        if ds_cfg and "message_field_content" in ds_cfg
+        else "value"
+    )
+    roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
+
     strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
+        ChatTemplatePrompter(
+            tokenizer,
+            chat_templates(chat_template),
+            message_field_role=message_field_role,
+            message_field_content=message_field_content,
+            roles=roles,
+        ),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+        strategy.messages = ds_cfg["field_messages"]
     return strategy
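One behavioral change in tokenize_prompt above is easy to miss: the masked
prompt span grows from just the first turn ([turns[0]]) to everything before
the final turn (turns[:-1]), so with train_on_inputs: false only the last
assistant reply contributes to the loss. A standalone sketch of that labeling
with toy token ids, assuming the usual Hugging Face -100 ignore index (the
actual masking lines are unchanged context elided from the hunk):

    # Toy ids standing in for a tokenized conversation; not real token ids.
    input_ids = [1, 10, 11, 12, 20, 21, 22]  # all turns, tokenized
    prompt_ids = [1, 10, 11, 12]             # turns[:-1] + generation prompt

    # With train_on_inputs=False, the prompt prefix is masked out of the loss.
    labels = [-100] * len(prompt_ids) + input_ids[len(prompt_ids):]
    assert labels == [-100, -100, -100, -100, 20, 21, 22]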
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index a14b66fa3..f363ebfdc 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -110,6 +110,8 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
+    message_field_role: Optional[str] = None
+    message_field_content: Optional[str] = None
     roles: Optional[Dict[str, List[str]]] = None
 
 
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
index 1076c6a3b..7b58a1236 100644
--- a/tests/prompt_strategies/test_chat_templates.py
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -1,6 +1,7 @@
 """
 tests for chat_template prompt strategy
 """
+
 import unittest
 
 import pytest
@@ -10,8 +11,39 @@
 from transformers import AutoTokenizer
 
 from axolotl.prompt_strategies.chat_template import (
     ChatTemplatePrompter,
     ChatTemplateStrategy,
+    load,
 )
 from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="assistant_dataset")
+def fixture_assistant_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "hello",
+                    },
+                    {
+                        "role": "user",
+                        "content": "goodbye",
+                    },
+                    {
+                        "role": "assistant",
+                        "content": "goodbye",
+                    },
+                ]
+            }
+        ]
+    )
 
 
 @pytest.fixture(name="sharegpt_dataset")
@@ -51,6 +83,87 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+class TestAssistantChatTemplateLlama3:
+    """
+    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
+    """
+
+    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = load(
+            llama3_tokenizer,
+            DictDefault(
+                {
+                    "train_on_inputs": False,
+                    "sequence_len": 512,
+                }
+            ),
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "message_field_role": "role",
+                    "message_field_content": "content",
+                    "roles": {
+                        "user": ["user"],
+                        "assistant": ["assistant"],
+                        "system": ["system"],
+                    },
+                    "field_messages": "messages",
+                }
+            ),
+        )
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+    def test_llama3(self, llama3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(
+                llama3_tokenizer,
+                chat_templates("llama3"),
+                message_field_role="role",
+                message_field_content="content",
+                roles={
+                    "user": ["user"],
+                    "assistant": ["assistant"],
+                    "system": ["system"],
+                },
+            ),
+            llama3_tokenizer,
+            False,
+            512,
+        )
+        strategy.messages = "messages"
+        res = strategy.tokenize_prompt(assistant_dataset[0])
+        input_ids = res["input_ids"]
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
 class TestSharegptChatTemplateLlama3:
     """
     Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
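For reference, a hedged end-to-end sketch mirroring test_llama3_load above;
it assumes axolotl is installed and the gated Meta-Llama-3 tokenizer is
accessible, and the sample record is illustrative only:

    from transformers import AutoTokenizer

    from axolotl.prompt_strategies.chat_template import load
    from axolotl.utils.dict import DictDefault

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    strategy = load(
        tokenizer,
        DictDefault({"train_on_inputs": False, "sequence_len": 512}),
        DictDefault(
            {
                "chat_template": "llama3",
                "field_messages": "messages",
                "message_field_role": "role",
                "message_field_content": "content",
                "roles": {"user": ["user"], "assistant": ["assistant"]},
            }
        ),
    )

    sample = {
        "messages": [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "hi"},
        ]
    }
    print(strategy.tokenize_prompt(sample)["input_ids"])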