From 8cc0aadcb8e53e50b29514033ff6b86944c71eec Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:56:17 +0900 Subject: [PATCH] Lint alpaca_chat --- src/axolotl/prompt_strategies/alpaca_chat.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 7b6ccea7d..29a0cb654 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -1,3 +1,6 @@ +"""Module containing the AlpacaQAPromptTokenizingStrategy class""" + +from typing import Tuple from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy, @@ -7,7 +10,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle def load(tokenizer, cfg): return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -15,7 +18,11 @@ def load(tokenizer, cfg): class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for AlpacaQA + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], "", @@ -25,7 +32,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): def load_qa(tokenizer, cfg): return AlpacaQAPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len,