diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 7b6ccea7d..29a0cb654 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -1,3 +1,6 @@ +"""Module containing the AlpacaQAPromptTokenizingStrategy class""" + +from typing import Tuple from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy, @@ -7,7 +10,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle def load(tokenizer, cfg): return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -15,7 +18,11 @@ def load(tokenizer, cfg): class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for AlpacaQA + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], "", @@ -25,7 +32,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): def load_qa(tokenizer, cfg): return AlpacaQAPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len,