From 01c8a333b3e1b66d64a6ea52f2187e5f5eb38b7e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:05:22 +0900 Subject: [PATCH] Lint pygmalion --- src/axolotl/prompt_strategies/pygmalion.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index ced15c3cf..01828a034 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -1,3 +1,5 @@ +"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" + import copy import logging from collections import defaultdict @@ -9,10 +11,14 @@ IGNORE_TOKEN_ID = -100 class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): + """ + Tokenizing strategy for Pygmalion. + """ + bot_prefix_token_ids = [] def __init__(self, prompter, tokenizer, *args, **kwargs): - super().__init__(prompter, tokenizer) + super().__init__(prompter, tokenizer, *args, **kwargs) res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True) self.bot_prefix_token_ids = res["input_ids"] @@ -23,7 +29,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): "labels": [], } current_len = 0 - for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): + for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): role, message = part if role == "system": prefix = "<|system|>" @@ -96,10 +102,16 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): class PygmalionPrompter: + """ + Prompter for Pygmalion. + """ + def __init__(self, *args, **kwargs): pass - def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]: + def build_prompt( + self, source, *args, **kwargs # pylint: disable=unused-argument + ) -> Generator[str, None, None]: for msg in source: yield msg["role"], msg["value"]