Lint alpaca_chat

This commit is contained in:
NanoCode012
2023-05-29 13:56:17 +09:00
parent 6abb7f6a16
commit 8cc0aadcb8

View File

@@ -1,3 +1,6 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
@@ -7,7 +10,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value),
AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -15,7 +18,11 @@ def load(tokenizer, cfg):
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def parse_instruction_fields(self, prompt) -> (str, str, str):
"""
Tokenizing strategy for AlpacaQA
"""
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"",
@@ -25,7 +32,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.chat.value),
AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,