diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 0f8c31d6a..1183c1e8e 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -20,11 +20,24 @@ def load(tokenizer, cfg):
 
 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
-    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
 
 
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +77,7 @@ def load_concise(tokenizer, cfg):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +86,7 @@ def load_qa(tokenizer, cfg):
 
 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,