Lint alpaca_chat

This commit is contained in:
NanoCode012
2023-05-29 13:56:17 +09:00
parent 6abb7f6a16
commit 8cc0aadcb8

View File

@@ -1,3 +1,6 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
@@ -7,7 +10,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.chat.value),
+ AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -15,7 +18,11 @@ def load(tokenizer, cfg):
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for AlpacaQA
+ """
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"",
@@ -25,7 +32,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.chat.value),
+ AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,