From 36596adaf7e94023f719b6de8b117df3c051589c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 01:08:36 +0900 Subject: [PATCH 01/59] Add pre-commit: black+flake8+pylint --- .github/workflows/pre-commit.yml | 16 ++++++++++++++++ .pre-commit-config.yaml | 19 +++++++++++++++++++ requirements-dev.txt | 1 + 3 files changed, 36 insertions(+) create mode 100644 .github/workflows/pre-commit.yml create mode 100644 .pre-commit-config.yaml create mode 100644 requirements-dev.txt diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 000000000..626edc686 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,16 @@ +name: pre-commit + +on: + pull_request: + push: + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + cache: 'pip' # caching pip dependencies + - uses: pre-commit/action@v3.0.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..ea958b7f6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,19 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 +- repo: https://github.com/PyCQA/pylint + rev: v2.17.4 + hooks: + - id: pylint diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..416634f52 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1 @@ +pre-commit From a98deb31a6304f25f36d740a1097a28167f1d495 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 03:11:19 +0900 Subject: [PATCH 02/59] Add config files --- .flake8 | 5 +++++ .pylintrc | 9 +++++++++ 2 files changed, 14 insertions(+) create mode 100644 .flake8 create mode 100644 .pylintrc diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..edf44df6a --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 88 + +select = C,E,F,W,B,B950 +extend-ignore = E203, E501 diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 000000000..09a8d4013 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,9 @@ +[TYPECHECK] + +# List of members which are set dynamically and missed by Pylint inference +# system, and so shouldn't trigger E1101 when accessed. +generated-members=numpy.*, torch.* + + +[pylint.messages_control] +disable=W1203 From 392dfd9b07ae950d94e512121b4e7e82281b3ef4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 03:45:42 +0900 Subject: [PATCH 03/59] Lint and format --- .gitignore | 2 +- docker/Dockerfile-base | 1 - examples/falcon/config-7b-lora.yml | 1 - examples/falcon/config-7b.yml | 1 - scripts/alpaca_json_to_jsonl.py | 26 +++++++++++--- scripts/finetune.py | 34 ++++++++++--------- src/axolotl/datasets.py | 11 +++--- src/axolotl/utils/data.py | 54 +++++++++++++++++------------- tests/test_prompters.py | 10 +++--- 9 files changed, 82 insertions(+), 58 deletions(-) diff --git a/.gitignore b/.gitignore index 93a4f81b5..614a6676b 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -.idea/ \ No newline at end of file +.idea/ diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index a61f6d42d..0ce43b621 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \ pip3 install awscli && \ # The base image ships with `pydantic==1.8.2` which is not working pip3 install -U --no-cache-dir pydantic - diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 1291198cf..090cc6bcf 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -61,4 +61,3 @@ special_tokens: pad_token: "<|endoftext|>" bos_token: ">>ABSTRACT<<" eos_token: "<|endoftext|>" - diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 787c4121c..dc67d6125 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -61,4 +61,3 @@ special_tokens: pad_token: "<|endoftext|>" bos_token: ">>ABSTRACT<<" eos_token: "<|endoftext|>" - diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index 98c968309..f535d1afc 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -1,23 +1,39 @@ +"""Module to convert json file to jsonl""" + import os import sys + +from typing import Optional from pathlib import Path import fire -from typing import Optional + + +from axolotl.convert import ( + FileReader, + StdoutWriter, + FileWriter, + JsonlSerializer, + JsonParser, + JsonToJsonlConverter, +) + # add src to the pythonpath so we don't need to pip install this project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) -from axolotl.convert import * - def main( - input: Path, + file: Path, output: Optional[Path] = None, to_stdout: Optional[bool] = False, ): + """ + Convert a json file to jsonl + """ + file_reader = FileReader() if to_stdout or output is None: writer = StdoutWriter() @@ -28,7 +44,7 @@ def main( converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer) - converter.convert(input, output) + converter.convert(file, output) if __name__ == "__main__": diff --git a/scripts/finetune.py b/scripts/finetune.py index 58f1c0957..029e94648 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,3 +1,5 @@ +"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" + import importlib import logging import os @@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.validation import validate_config from axolotl.utils.dict import DictDefault -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - from axolotl.utils.data import load_prepare_datasets from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer from axolotl.utils.wandb import setup_wandb_env_vars +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + + logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" @@ -37,7 +40,7 @@ def choose_device(cfg): try: if torch.backends.mps.is_available(): return "mps" - except: + except Exception: # pylint: disable=broad-exception-caught return "cpu" cfg.device = get_device() @@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): model.eval() with torch.no_grad(): - # gc = GenerationConfig() # TODO swap out and use this + # gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme generated = model.generate( inputs=batch["input_ids"].to(cfg.device), do_sample=True, @@ -130,12 +133,12 @@ def train( config = choose_config(config) # load the config from the yaml file - with open(config, "r") as f: - cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader)) + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader)) # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value cfg_keys = cfg.keys() - for k in kwargs: + for k, _ in kwargs.items(): # if not strict, allow writing to cfg even if it's not in the yml already if k in cfg_keys or cfg.strict is False: # handle booleans @@ -167,13 +170,11 @@ def train( # load the tokenizer first logging.info("loading tokenizer...") - tokenizer = load_tokenizer( - cfg.base_model_config, - cfg.tokenizer_type, - cfg - ) + tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg) - if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these + if check_not_in( + ["inference", "shard", "merge_lora"], kwargs + ): # don't need to load dataset for these train_dataset, eval_dataset = load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) @@ -262,10 +263,13 @@ def train( logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + # pylint: disable=fixme # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.local_rank == 0: model.save_pretrained(cfg.output_dir) + + # pylint: disable=fixme # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 0e166f6f0..c7bb9fbfe 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset): else: example_len = 0 - if ( - not example_len - or buffer_len + int(add_concat_token) + example_len - > self.seq_length + if not example_len or ( + buffer_len + int(add_concat_token) + example_len > self.seq_length ): if buffer["input_ids"]: input_ids = torch.cat(buffer["input_ids"], dim=-1)[ @@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset): : self.seq_length ] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if ( - labels.size() == input_ids.size() - and attention_mask.size() == input_ids.size() + if labels.size() == input_ids.size() and ( + attention_mask.size() == input_ids.size() ): yield { "input_ids": input_ids, diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index a0cff21c4..6d2123eea 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,14 +1,12 @@ import logging from hashlib import md5 from pathlib import Path -from typing import Union +from typing import Tuple, Union from datasets import ( load_from_disk, load_dataset, - IterableDataset, Dataset, - concatenate_datasets, DatasetDict, ) from huggingface_hub import hf_hub_download @@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets( md5( ( str(cfg.sequence_len) - + "@" - + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) - + "|" - + tokenizer_name + + "@" # noqa: W503 + + "|".join( # noqa: W503 + sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) + ) + + "|" # noqa: W503 + + tokenizer_name # noqa: W503 ).encode("utf-8") ).hexdigest() ) @@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] - except: + except Exception: # pylint: disable=broad-except pass if dataset: @@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets( fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None) + ds: Dataset = load_dataset( + "json", data_files=fp, streaming=False, split=None + ) if not ds: - raise Exception("unhandled dataset load") + raise ValueError("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) + ds: DatasetDict = ds.shuffle(seed=42)["train"].shard( + num_shards=d.shards, index=0 + ) else: - ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) + ds: Dataset = ds.shuffle(seed=42).shard( + num_shards=d.shards, index=0 + ) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] @@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets( def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path -) -> (Dataset, Dataset): +) -> Tuple[Dataset, Dataset]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -259,12 +265,14 @@ def load_prepare_datasets( md5( ( str(cfg.sequence_len) - + "@" - + str(max_packed_sequence_len) - + seed - + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) - + "|" - + tokenizer_name + + "@" # noqa: W503 + + str(max_packed_sequence_len) # noqa: W503 + + seed # noqa: W503 + + "|".join( # noqa: W503 + sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) + ) + + "|" # noqa: W503 + + tokenizer_name # noqa: W503 ).encode("utf-8") ).hexdigest() ) @@ -285,7 +293,7 @@ def load_prepare_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] - except: + except Exception: # pylint: disable=broad-except pass if dataset: @@ -327,9 +335,9 @@ def load_prepare_datasets( d for d in dataset if len(d["input_ids"]) < cfg.sequence_len - and len(d["input_ids"]) > 0 - and len(d["input_ids"]) == len(d["attention_mask"]) - and len(d["input_ids"]) == len(d["labels"]) + and len(d["input_ids"]) > 0 # noqa: W503 + and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503 + and len(d["input_ids"]) == len(d["labels"]) # noqa: W503 ] ) diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 1c3c13852..b4a34c6c0 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase): def test_prompt_style_w_instruct(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) - res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) assert "Below is an instruction" in res assert "### Instruction:" in res assert "### Input:" in res @@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase): def test_prompt_style_w_chat(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) - res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) assert "Below is an instruction" in res assert "### Instruction:" not in res assert "### Input:" not in res @@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Response:" not in res assert "USER:" in res assert "ASSISTANT:" in res - - From c3a46970167cc05508bc270596f6b26e551b6dce Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 03:53:55 +0900 Subject: [PATCH 04/59] Update ignores --- .flake8 | 2 +- .pylintrc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.flake8 b/.flake8 index edf44df6a..fd69af775 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ max-line-length = 88 select = C,E,F,W,B,B950 -extend-ignore = E203, E501 +extend-ignore = E203, E501, W503 diff --git a/.pylintrc b/.pylintrc index 09a8d4013..9cf1babc3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -6,4 +6,4 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=W1203 +disable=W1203, C0116, C0301 From d57ba56746c083f736b92900fa6b5b2417548762 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:20:04 +0900 Subject: [PATCH 05/59] Ignore import and too many * pylint errors --- .pylintrc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 9cf1babc3..11180e7f6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,3 +1,6 @@ +[MASTER] +init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))" + [TYPECHECK] # List of members which are set dynamically and missed by Pylint inference @@ -6,4 +9,4 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=W1203, C0116, C0301 +disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915 From cb7cd3429fba1aa83d7827759f6e09e2441de409 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:21:08 +0900 Subject: [PATCH 06/59] Fix data.py lint --- src/axolotl/utils/data.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6d2123eea..32654e104 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,3 +1,5 @@ +"""Module containing data utilities for Axolotl""" + import logging from hashlib import md5 from pathlib import Path @@ -46,12 +48,12 @@ def load_tokenized_prepared_datasets( md5( ( str(cfg.sequence_len) - + "@" # noqa: W503 - + "|".join( # noqa: W503 + + "@" + + "|".join( sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) ) - + "|" # noqa: W503 - + tokenizer_name # noqa: W503 + + "|" + + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -81,6 +83,7 @@ def load_tokenized_prepared_datasets( logging.info(f"Unable to find prepared dataset in {prepared_ds_path}") logging.info("Loading raw datasets...") datasets = [] + # pylint: disable=invalid-name for d in cfg.datasets: ds: Union[Dataset, DatasetDict] = None ds_from_hub = False @@ -229,7 +232,7 @@ def load_tokenized_prepared_datasets( samples = [] for d in datasets: - samples = samples + [i for i in d] + samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=42) if cfg.local_rank == 0: logging.info( @@ -265,14 +268,14 @@ def load_prepare_datasets( md5( ( str(cfg.sequence_len) - + "@" # noqa: W503 - + str(max_packed_sequence_len) # noqa: W503 - + seed # noqa: W503 - + "|".join( # noqa: W503 + + "@" + + str(max_packed_sequence_len) + + seed + + "|".join( sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) ) - + "|" # noqa: W503 - + tokenizer_name # noqa: W503 + + "|" + + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -327,7 +330,7 @@ def load_prepare_datasets( logging.info( f"packing master dataset to len: {cfg.max_packed_sequence_len}" ) - dataset = Dataset.from_list([_ for _ in constant_len_dataset]) + dataset = Dataset.from_list(list(constant_len_dataset)) # filter out bad data dataset = Dataset.from_list( @@ -335,9 +338,9 @@ def load_prepare_datasets( d for d in dataset if len(d["input_ids"]) < cfg.sequence_len - and len(d["input_ids"]) > 0 # noqa: W503 - and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503 - and len(d["input_ids"]) == len(d["labels"]) # noqa: W503 + and len(d["input_ids"]) > 0 + and len(d["input_ids"]) == len(d["attention_mask"]) + and len(d["input_ids"]) == len(d["labels"]) ] ) From 903ea3080dd4acb69dd6a196079472eaa1b3882a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:37:44 +0900 Subject: [PATCH 07/59] Fix lint --- src/axolotl/prompt_strategies/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 803eb970c..2f6af208c 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -1,3 +1,5 @@ +"""Module to load prompt strategies.""" + import importlib @@ -7,8 +9,8 @@ def load(strategy, tokenizer, cfg): if strategy.split(".")[-1].startswith("load_"): load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) - m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") - fn = getattr(m, load_fn) - return fn(tokenizer, cfg) - except: - pass + mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") + func = getattr(mod, load_fn) + return func(tokenizer, cfg) + except Exception: # pylint: disable=broad-exception-caught + return None From 1c60c10e0076de9ff6e9c115a6919281ba3b3ecd Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:50:36 +0900 Subject: [PATCH 08/59] Lint flash_attn.py --- src/axolotl/flash_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index c1ceec788..d532e15a8 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -1,9 +1,10 @@ +"""Flash attention monkey patch for llama model""" + # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch -from torch import nn import transformers from transformers.models.llama.modeling_llama import apply_rotary_pos_emb @@ -14,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.bert_padding import unpad_input, pad_input -def forward( +def forward( # pylint: disable=too-many-arguments self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -82,6 +83,8 @@ def forward( output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: nheads = qkv.shape[-2] + + # pylint: disable=invalid-name x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( @@ -104,13 +107,13 @@ def forward( # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( self, attention_mask, input_shape, inputs_embeds, past_key_values_length -): +): # pylint: disable=unused-argument # [bsz, seq_len] return attention_mask def replace_llama_attn_with_flash_attn(): - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask ) transformers.models.llama.modeling_llama.LlamaAttention.forward = forward From 4c0eddb3f8743f77c91fcaac152a45c7d0d5584f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 09:51:45 +0900 Subject: [PATCH 09/59] Refactor --- src/axolotl/utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 32654e104..3164f2ecc 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,4 +1,4 @@ -"""Module containing data utilities for Axolotl""" +"""Module containing data utilities""" import logging from hashlib import md5 From cb4f0e93420dc470a9fb75c5e28a37d4e68f544e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:03:13 +0900 Subject: [PATCH 10/59] Lint prompters.py --- src/axolotl/prompters.py | 98 ++++++++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 25 deletions(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 760c714d6..784a86c13 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,28 +1,37 @@ -import copy +"""Module containing prompters""" + import dataclasses import logging from enum import auto, Enum -from typing import List, Tuple, Any, Union, Generator +from typing import List, Union, Generator IGNORE_TOKEN_ID = -100 class PromptStyle(Enum): - instruct = "instruct" - chat = "chat" + """ + Enum for prompt styles + """ + + INSTRUCT = "instruct" + CHAT = "chat" class AlpacaPrompter: + """ + Base class for alpaca prompters + """ + system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" prompt_style = None - def __init__(self, prompt_style=PromptStyle.instruct.value): - self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value + def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): + self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value self.match_prompt_style() def match_prompt_style(self): - if self.prompt_style == PromptStyle.instruct.value: + if self.prompt_style == PromptStyle.INSTRUCT.value: self.prompt_input = ( self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" @@ -32,7 +41,7 @@ class AlpacaPrompter: + "### Instruction:\n{instruction}\n\n### Response:\n" ) self.response_split = "### Response:" - if self.prompt_style == PromptStyle.chat.value: + if self.prompt_style == PromptStyle.CHAT.value: self.prompt_input = ( self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" ) @@ -44,7 +53,7 @@ class AlpacaPrompter: def build_prompt( self, instruction: str, - input: Union[None, str] = None, + input: Union[None, str] = None, # pylint: disable=redefined-builtin output: Union[None, str] = None, ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input @@ -62,33 +71,60 @@ class AlpacaPrompter: class UnpromptedPrompter(AlpacaPrompter): + """ + Prompter for alpaca no system prompt + """ + system_prompt = "" system_no_input_prompt = "" class JeopardyPrompter(AlpacaPrompter): + """ + Prompter for Jeopardy + """ + prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" class MultipleChoiceExplainPrompter(AlpacaPrompter): + """ + Prompter for multiple choice explain + """ + system_prompt = ( "Choose the answer that best answers the question. Explain your reasoning." ) class MultipleChoiceConcisePrompter(AlpacaPrompter): + """ + Prompter for multiple choice concise + """ + prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n" class SummarizeTLDRPrompter(AlpacaPrompter): + """ + Prompter for summarize TLDR + """ + prompt_no_input = ( "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:" ) class CompletionPrompter: + """ + Prompter for completion + """ + def build_prompt( - self, instruction: str, input=None, output=None + self, + instruction: str, + input=None, # pylint: disable=redefined-builtin, unused-argument + output=None, # pylint: disable=unused-argument ) -> Generator[str, None, None]: yield instruction @@ -97,14 +133,22 @@ class CompletionPrompter: class GPTeacherPrompter(AlpacaPrompter): - ... + """ + Prompter for GPTeacher + """ class NomicGPT4AllPrompter(AlpacaPrompter): - ... + """ + Prompter for NomicGPT4All + """ class ReflectAlpacaPrompter: + """ + Prompter for ReflectAlpaca + """ + system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" @@ -120,7 +164,7 @@ class ReflectAlpacaPrompter: self.match_prompt_style() def match_prompt_style(self): - if self.prompt_style == PromptStyle.instruct.value: + if self.prompt_style == PromptStyle.INSTRUCT.value: self.prompt_input = ( self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" @@ -131,7 +175,7 @@ class ReflectAlpacaPrompter: ) self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" self.response_split = "### Final Response:" - if self.prompt_style == PromptStyle.chat.value: + if self.prompt_style == PromptStyle.CHAT.value: self.prompt_input = ( self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" ) @@ -143,10 +187,10 @@ class ReflectAlpacaPrompter: ) self.response_split = "ASSISTANT:" - def build_prompt( + def build_prompt( # pylint: disable=too-many-arguments self, instruction: str, - input: Union[None, str] = None, + input: Union[None, str] = None, # pylint: disable=redefined-builtin output: Union[None, str] = None, reflection: Union[None, str] = None, corrected: Union[None, str] = None, @@ -176,7 +220,7 @@ class SeparatorStyle(Enum): DOLLY = auto() -# TODO clean this 💩 up +# TODO clean this 💩 up # pylint: disable=fixme @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" @@ -193,11 +237,11 @@ class Conversation: seps = [self.sep, self.sep2] preamble = self.system + seps[0] yield preamble - for i, (role, message) in enumerate(self.messages): + for _, (role, message) in enumerate(self.messages): if message: yield (role + ":", " " + message) else: - logging.warning("role with empty message: " + role) + logging.warning(f"role with empty message: {role}") yield (role + ":",) def copy(self): @@ -227,10 +271,14 @@ conv_vicuna_v1_1 = Conversation( ) -class ShareGPTPrompter: +class ShareGPTPrompter: # pylint: disable=too-few-public-methods + """ + A prompter that generates prompts for the ShareGPT + """ + def __init__(self, prompt_style=None): - if prompt_style != PromptStyle.chat.value: - raise Exception( + if prompt_style != PromptStyle.CHAT.value: + raise ValueError( f"unsupported prompt_style for ShareGPTPrompter({prompt_style})" ) @@ -240,7 +288,7 @@ class ShareGPTPrompter: # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" # self.response_split = "ASSISTANT:" - def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]: + def build_prompt(self, source) -> Generator[str, None, None]: # ignore the system prompt if provided if source[0]["from"] == "system": source.pop(0) @@ -261,9 +309,9 @@ class ShareGPTPrompter: ): # Skip the first one if it is not from human source = source[1:] - except IndexError as e: + except IndexError as err: # sometimes there is a bing or system chat - raise e + raise err conv.messages = [] for j, sentence in enumerate(source): From 5062eca069780d25b0ec1afa8ba6f04d4a5f0864 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:05:11 +0900 Subject: [PATCH 11/59] Lint callbacks.py --- src/axolotl/utils/callbacks.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 229cd9b98..70e83d6e4 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -1,7 +1,8 @@ +"""Callbacks for Trainer class""" + import os from transformers import ( - Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, @@ -10,7 +11,9 @@ from transformers import ( from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -class SavePeftModelCallback(TrainerCallback): +class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods + """Callback to save the PEFT adapter""" + def on_save( self, args: TrainingArguments, From 54c3b5b25ff86a43f565fa9602d97d5bb038e6ba Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:08:23 +0900 Subject: [PATCH 12/59] Ignore too-many-arguments --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 11180e7f6..53fd6533e 100644 --- a/.pylintrc +++ b/.pylintrc @@ -9,4 +9,4 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915 +disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915, R0913 From e8717d3bef74a4cb9ce8e0cbda2d1da78e64c4bb Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:08:30 +0900 Subject: [PATCH 13/59] Remove disable --- src/axolotl/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index d532e15a8..c7bd12c66 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -15,7 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.bert_padding import unpad_input, pad_input -def forward( # pylint: disable=too-many-arguments +def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, From 5658717dbd4f523a0b08b67c41865d0564413f72 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:16:31 +0900 Subject: [PATCH 14/59] Remove disable too many arg --- src/axolotl/prompters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 784a86c13..418758ef7 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -187,7 +187,7 @@ class ReflectAlpacaPrompter: ) self.response_split = "ASSISTANT:" - def build_prompt( # pylint: disable=too-many-arguments + def build_prompt( self, instruction: str, input: Union[None, str] = None, # pylint: disable=redefined-builtin From 69722aeef4dabcb54e010e97fd20e5712a333738 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:22:09 +0900 Subject: [PATCH 15/59] Remove fixme disable --- .pylintrc | 2 +- src/axolotl/prompters.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pylintrc b/.pylintrc index 53fd6533e..512af30e0 100644 --- a/.pylintrc +++ b/.pylintrc @@ -9,4 +9,4 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915, R0913 +disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915, R0913, fixme diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 418758ef7..eced1d4a5 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -220,7 +220,7 @@ class SeparatorStyle(Enum): DOLLY = auto() -# TODO clean this 💩 up # pylint: disable=fixme +# TODO clean this 💩 up @dataclasses.dataclass class Conversation: """A class that keeps all conversation history.""" From 545cfeb5c76a55023e23cdddd0dd31430383ff48 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 10:30:32 +0900 Subject: [PATCH 16/59] Refactor error code to use full error message --- .pylintrc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 512af30e0..7535c7e89 100644 --- a/.pylintrc +++ b/.pylintrc @@ -9,4 +9,8 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=W1203, C0116, C0301, E0401, R0912, R0914, R0915, R0913, fixme +disable=missing-function-docstring, line-too-long, import-error + too-many-arguments, too-many-locals, too-many-statements, too-many-branches + fixme, + import-outside-toplevel, + logging-fstring-interpolation, From daf47ccf45420b81caec193ac0dcacec4ca51799 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:18:25 +0900 Subject: [PATCH 17/59] Refactor disable pylint --- .pylintrc | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.pylintrc b/.pylintrc index 7535c7e89..bac1aaa66 100644 --- a/.pylintrc +++ b/.pylintrc @@ -9,8 +9,6 @@ generated-members=numpy.*, torch.* [pylint.messages_control] -disable=missing-function-docstring, line-too-long, import-error - too-many-arguments, too-many-locals, too-many-statements, too-many-branches - fixme, - import-outside-toplevel, - logging-fstring-interpolation, +disable=missing-function-docstring, line-too-long, import-error, + too-many-arguments, too-many-locals, too-many-statements, too-many-branches, + fixme, import-outside-toplevel, logging-fstring-interpolation, From f4e5d862682dbb7efb1a2d128c28689bae915c17 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:32:04 +0900 Subject: [PATCH 18/59] Lint models.py --- src/axolotl/utils/models.py | 64 ++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 07872a16e..7a81b8a49 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,13 +1,16 @@ +"""Module for models and model loading""" + + import logging import math import os from pathlib import Path -from typing import Optional, Tuple, TYPE_CHECKING +from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import ( +from transformers import ( # noqa: F401 AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, @@ -18,9 +21,8 @@ from transformers import ( try: from transformers import ( LlamaForCausalLM, - LlamaTokenizer, ) -except: +except ImportError: logging.warning( "This version of transformers does not support Llama. Consider upgrading." ) @@ -28,9 +30,9 @@ except: from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN if TYPE_CHECKING: - from peft import PeftModel, PeftConfig - from axolotl.utils.dict import DictDefault - from transformers import PreTrainedTokenizer + from peft import PeftConfig # noqa: F401 + from axolotl.utils.dict import DictDefault # noqa: F401 + from transformers import PreTrainedTokenizer # noqa: F401 def load_tokenizer( @@ -62,8 +64,8 @@ def load_tokenizer( os.environ["TOKENIZERS_PARALLELISM"] = "false" if cfg.special_tokens: - for k, v in cfg.special_tokens.items(): - tokenizer.add_special_tokens({k: v}) + for k, val in cfg.special_tokens.items(): + tokenizer.add_special_tokens({k: val}) if cfg.tokens: tokenizer.add_tokens(list(cfg.tokens)) @@ -80,6 +82,9 @@ def load_model( inference=False, ): # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] + """ + Load a model from a base model and a model type. + """ # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit @@ -115,9 +120,9 @@ def load_model( replace_peft_model_with_int4_lora_model() from peft import prepare_model_for_int8_training - except Exception as e: - logging.exception(e) - raise e + except Exception as err: + logging.exception(err) + raise err model_kwargs = {} if cfg.adapter == "qlora" and cfg.load_in_4bit: @@ -155,7 +160,7 @@ def load_model( "unable to find a cached model file, this will likely fail..." ) model_path = str(cache_model_path) - except: + except Exception: # pylint: disable=broad-exception-caught model_path = cfg.base_model model, _ = load_llama_model_4bit_low_ram( base_model_config if base_model_config else base_model, @@ -210,13 +215,13 @@ def load_model( load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, - trust_remote_code=True if cfg.trust_remote_code is True else False, + trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) else: config = AutoConfig.from_pretrained( base_model, - trust_remote_code=True if cfg.trust_remote_code is True else False, + trust_remote_code=cfg.trust_remote_code or False, ) model = AutoModelForCausalLM.from_pretrained( base_model, @@ -225,30 +230,29 @@ def load_model( load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, - trust_remote_code=True if cfg.trust_remote_code is True else False, + trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) - except Exception as e: + except Exception as err: # pylint: disable=broad-exception-caught logging.error( "Exception raised attempting to load model, retrying with AutoModelForCausalLM" ) - logging.exception(e) + logging.exception(err) model = AutoModelForCausalLM.from_pretrained( base_model, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, torch_dtype=torch_dtype, device_map=cfg.device_map, - trust_remote_code=True if cfg.trust_remote_code is True else False, + trust_remote_code=cfg.trust_remote_code or False, **model_kwargs, ) embeddings_len = math.ceil(len(tokenizer) / 32) * 32 model.resize_token_embeddings(embeddings_len) - if ( - ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora") - and not cfg.gptq - and (load_in_8bit or cfg.load_in_4bit) + if not cfg.gptq and ( + (cfg.adapter == "lora" and load_in_8bit) + or (cfg.adapter == "qlora" and cfg.load_in_4bit) ): logging.info("converting PEFT model w/ prepare_model_for_int8_training") model = prepare_model_for_int8_training(model) @@ -261,14 +265,14 @@ def load_model( if cfg.gptq: # Scales to half logging.info("Fitting 4bit scales and zeros to half") - for n, m in model.named_modules(): - if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str( - type(m) + for _, module in model.named_modules(): + if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str( + type(module) ): - if hasattr(m, "is_v1_model") and m.is_v1_model: - m.zeros = m.zeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() + if hasattr(module, "is_v1_model") and module.is_v1_model: + module.zeros = module.zeros.half() + module.scales = module.scales.half() + module.bias = module.bias.half() if ( torch.cuda.device_count() > 1 From 82971e1565abf3f7820ef98d16d16dcccda3dbd9 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:40:12 +0900 Subject: [PATCH 19/59] Lint finetune.py --- scripts/finetune.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 029e94648..226068020 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -34,14 +34,16 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" def choose_device(cfg): def get_device(): - if torch.cuda.is_available(): - return f"cuda:{cfg.local_rank}" - else: - try: - if torch.backends.mps.is_available(): - return "mps" - except Exception: # pylint: disable=broad-exception-caught - return "cpu" + try: + if torch.cuda.is_available(): + return f"cuda:{cfg.local_rank}" + + if torch.backends.mps.is_available(): + return "mps" + + raise SystemError("No CUDA/mps device found") + except Exception: # pylint: disable=broad-exception-caught + return "cpu" cfg.device = get_device() if cfg.device == "cuda": @@ -54,7 +56,7 @@ def get_multi_line_input() -> Optional[str]: print("Give me an instruction (Ctrl + D to finish): ") instruction = "" for line in sys.stdin: - instruction += line + instruction += line # pylint: disable=consider-using-join # instruction = pathlib.Path("/proc/self/fd/0").read_text() return instruction @@ -76,7 +78,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): model.eval() with torch.no_grad(): - # gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme + # gc = GenerationConfig() # TODO swap out and use this generated = model.generate( inputs=batch["input_ids"].to(cfg.device), do_sample=True, @@ -95,7 +97,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def choose_config(path: Path): - yaml_files = [file for file in path.glob("*.yml")] + yaml_files = list(path.glob("*.yml")) if not yaml_files: raise ValueError( @@ -240,7 +242,7 @@ def train( if cfg.local_rank == 0: signal.signal( signal.SIGINT, - lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)), + lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)), ) logging.info("Starting trainer...") @@ -263,13 +265,11 @@ def train( logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") - # pylint: disable=fixme # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.local_rank == 0: model.save_pretrained(cfg.output_dir) - # pylint: disable=fixme # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time From 1a2bd7ff6212abc7986f5616d6b29d6ec77a84bc Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:44:28 +0900 Subject: [PATCH 20/59] Ignore too-few-public-methods --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index bac1aaa66..982bad52d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -10,5 +10,5 @@ generated-members=numpy.*, torch.* [pylint.messages_control] disable=missing-function-docstring, line-too-long, import-error, - too-many-arguments, too-many-locals, too-many-statements, too-many-branches, + too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, fixme, import-outside-toplevel, logging-fstring-interpolation, From ddb86ea82144c193f8e79656a5461f819f4f219f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:46:49 +0900 Subject: [PATCH 21/59] Lint trainer.py --- src/axolotl/utils/trainer.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 97b02baba..299e39664 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,3 +1,5 @@ +"""Module containing the Trainer class and related functions""" + import importlib import math import os @@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback class OneCycleLRSchedulerTrainer(Trainer): + """ + Trainer subclass that uses the OneCycleLR scheduler + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + def create_scheduler( self, num_training_steps: int, optimizer: torch.optim.Optimizer = None ): optimizer = self.optimizer if optimizer is None else optimizer num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - num_training_steps = num_training_steps pct_start = num_warmup_steps / num_training_steps self.lr_scheduler = OneCycleLR( @@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["bf16_full_eval"] = True else: training_arguments_kwargs["bf16"] = cfg.bf16 - training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False + training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False training_arguments_kwargs["tf32"] = cfg.tf32 training_arguments_kwargs["warmup_steps"] = warmup_steps training_arguments_kwargs["logging_steps"] = logging_steps - if cfg.gradient_checkpointing is not None: + if cfg.gradient_checkpointing: if cfg.gptq: from alpaca_lora_4bit.gradient_checkpointing import ( apply_gradient_checkpointing, @@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): save_steps=save_steps, output_dir=cfg.output_dir, save_total_limit=3, - load_best_model_at_end=True - if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False - and cfg.val_set_size > 0 - and save_steps is not None - and save_steps % eval_steps == 0 - and cfg.load_in_8bit is not True - else False, + load_best_model_at_end=( + cfg.val_set_size > 0 + and save_steps + and save_steps % eval_steps == 0 + and cfg.load_in_8bit is not True + ) + or False, ddp_find_unused_parameters=False if cfg.ddp else None, group_by_length=cfg.group_by_length, report_to="wandb" if cfg.use_wandb else None, @@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if ( cfg.optimizer == "adamw_bnb_8bit" and not cfg.gptq - and not "deepspeed" in training_arguments_kwargs + and "deepspeed" not in training_arguments_kwargs and not cfg.fsdp ): decay_parameters = get_parameter_names(model, [nn.LayerNorm]) From 8b617cc7f6a0ecd57e350dafe1a27fe968ec3bda Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:48:07 +0900 Subject: [PATCH 22/59] Lint setup.py --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 134e4be66..7f51f495f 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,9 @@ +"""setup.py for axolotl""" + from setuptools import setup, find_packages install_requires = [] -with open("./requirements.txt", "r") as requirements_file: +with open("./requirements.txt", encoding="utf-8") as requirements_file: # don't include peft yet until we check the int4 # need to manually install peft for now... reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r] From de2406c4884ec109ed40a421e0eadb1539e1e91b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:51:43 +0900 Subject: [PATCH 23/59] Lint convert.py --- src/axolotl/convert.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py index a953252e9..357e0ec50 100644 --- a/src/axolotl/convert.py +++ b/src/axolotl/convert.py @@ -1,47 +1,76 @@ +"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes""" + + import json import sys class FileReader: + """ + Reads a file and returns its contents as a string + """ + def read(self, file_path): - with open(file_path, "r") as file: + with open(file_path, encoding="utf-8") as file: return file.read() class FileWriter: + """ + Writes a string to a file + """ + def __init__(self, file_path): self.file_path = file_path def write(self, content): - with open(self.file_path, "w") as file: + with open(self.file_path, "w", encoding="utf-8") as file: file.write(content) class StdoutWriter: + """ + Writes a string to stdout + """ + def write(self, content): sys.stdout.write(content) sys.stdout.write("\n") class JsonParser: + """ + Parses a string as JSON and returns the result + """ + def parse(self, content): return json.loads(content) class JsonlSerializer: + """ + Serializes a list of JSON objects into a JSONL string + """ + def serialize(self, data): lines = [json.dumps(item) for item in data] return "\n".join(lines) class JsonToJsonlConverter: + """ + Converts a JSON file to JSONL + """ + def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer): self.file_reader = file_reader self.file_writer = file_writer self.json_parser = json_parser self.jsonl_serializer = jsonl_serializer - def convert(self, input_file_path, output_file_path): + def convert( + self, input_file_path, output_file_path + ): # pylint: disable=unused-argument content = self.file_reader.read(input_file_path) data = self.json_parser.parse(content) # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations From 6abb7f6a1603f4955f0e3851f7bbcd4026f76ad7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:54:06 +0900 Subject: [PATCH 24/59] Lint datasets --- src/axolotl/datasets.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index c7bb9fbfe..1e72be114 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,3 +1,5 @@ +"""Module containing Dataset functionality""" + import logging from typing import List @@ -14,7 +16,14 @@ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException class TokenizedPromptDataset(IterableDataset): - def __init__( + """ + Iterable dataset that returns tokenized prompts from a stream of text files. + Args: + prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data. + dataset (dataset.Dataset): Dataset with text files. + """ + + def __init__( # pylint: disable=super-init-not-called self, prompt_tokenizer: PromptTokenizingStrategy, dataset: IterableDataset, @@ -42,7 +51,7 @@ class ConstantLengthDataset(IterableDataset): seq_length (int): Length of token sequences to return. """ - def __init__( + def __init__( # pylint: disable=super-init-not-called self, tokenizer, datasets, From 8cc0aadcb8e53e50b29514033ff6b86944c71eec Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:56:17 +0900 Subject: [PATCH 25/59] Lint alpaca_chat --- src/axolotl/prompt_strategies/alpaca_chat.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 7b6ccea7d..29a0cb654 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -1,3 +1,6 @@ +"""Module containing the AlpacaQAPromptTokenizingStrategy class""" + +from typing import Tuple from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy, @@ -7,7 +10,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle def load(tokenizer, cfg): return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -15,7 +18,11 @@ def load(tokenizer, cfg): class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for AlpacaQA + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], "", @@ -25,7 +32,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): def load_qa(tokenizer, cfg): return AlpacaQAPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.chat.value), + AlpacaPrompter(PromptStyle.CHAT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, From 145b060cbe9839220bdccdc4448a5e5df61a48a3 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 13:57:09 +0900 Subject: [PATCH 26/59] Lint alpaca_instruct --- src/axolotl/prompt_strategies/alpaca_instruct.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/alpaca_instruct.py b/src/axolotl/prompt_strategies/alpaca_instruct.py index 6bce47ccd..0d0b267a6 100644 --- a/src/axolotl/prompt_strategies/alpaca_instruct.py +++ b/src/axolotl/prompt_strategies/alpaca_instruct.py @@ -1,10 +1,12 @@ +"""Module loading the AlpacaInstructPromptTokenizingStrategy class""" + from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter, PromptStyle def load(tokenizer, cfg): return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.instruct), + AlpacaPrompter(PromptStyle.INSTRUCT), tokenizer, cfg.train_on_inputs, cfg.sequence_len, From 1645a4ddd5a4261798b80991d3a40ebe7c4124ad Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:01:26 +0900 Subject: [PATCH 27/59] Lint creative_acr --- src/axolotl/prompt_strategies/creative_acr.py | 42 ++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py index 58e8b2bee..5cf89127d 100644 --- a/src/axolotl/prompt_strategies/creative_acr.py +++ b/src/axolotl/prompt_strategies/creative_acr.py @@ -1,11 +1,17 @@ -from typing import Union, Generator +"""Module loading the CreativePromptTokenizingStrategy and similar classes""" + +from typing import Tuple, Union, Generator import yaml from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Creative Answering + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: question = prompt["instruction"] answer = prompt[ "revision" @@ -18,6 +24,10 @@ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrat class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + """ + Tokenizing strategy for Creative Critique + """ + user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria: refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias. @@ -49,7 +59,7 @@ Question: {question} Answer: {answer} """ - def parse_instruction_fields(self, prompt) -> (str, str, str): + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper ) @@ -67,6 +77,10 @@ Answer: {answer} class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + """ + Tokenizing strategy for Creative Revise + """ + user_prompt = """Definitions: refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias. @@ -81,7 +95,7 @@ Evaluation: {evaluation} """ - def parse_instruction_fields(self, prompt) -> (str, str, str): + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper ) @@ -101,13 +115,19 @@ Evaluation: class CreativePrompterBase: + """ + Base class for Creative Prompters + """ + system_prompt = "" prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:" def build_prompt( self, instruction: str, - input: Union[None, str] = None, + input: Union[ # pylint: disable=redefined-builtin, unused-argument + None, str + ] = None, output: Union[None, str] = None, ) -> Generator[str, None, None]: if self.system_prompt: @@ -120,14 +140,26 @@ class CreativePrompterBase: class CreativeAnswerPrompter(CreativePrompterBase): + """ + Prompter for Creative Answering + """ + system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity." class CreativeCritiquePrompter(CreativePrompterBase): + """ + Prompter for Creative Critique + """ + system_prompt = "" class CreativeRevisePrompter(CreativePrompterBase): + """ + Prompter for Creative Revise + """ + system_prompt = "" From 7eb33a77dde99a5434eab2c9ddd4c66acd3cee05 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:02:43 +0900 Subject: [PATCH 28/59] Lint test_prompters --- tests/test_prompters.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_prompters.py b/tests/test_prompters.py index b4a34c6c0..11610ccc5 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -1,9 +1,15 @@ +"""Module testing prompters""" + import unittest from axolotl.prompters import AlpacaPrompter, PromptStyle class AlpacaPrompterTest(unittest.TestCase): + """ + Test AlpacaPrompter + """ + def test_prompt_style_w_none(self): prompter = AlpacaPrompter(prompt_style=None) res = next(prompter.build_prompt("tell me a joke")) @@ -11,7 +17,7 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Instruction:" in res def test_prompt_style_w_instruct(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) + prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value) res = next( prompter.build_prompt("tell me a joke about the following", "alpacas") ) @@ -31,7 +37,7 @@ class AlpacaPrompterTest(unittest.TestCase): assert "ASSISTANT:" not in res def test_prompt_style_w_chat(self): - prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) + prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value) res = next( prompter.build_prompt("tell me a joke about the following", "alpacas") ) From 01c8a333b3e1b66d64a6ea52f2187e5f5eb38b7e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:05:22 +0900 Subject: [PATCH 29/59] Lint pygmalion --- src/axolotl/prompt_strategies/pygmalion.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index ced15c3cf..01828a034 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -1,3 +1,5 @@ +"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" + import copy import logging from collections import defaultdict @@ -9,10 +11,14 @@ IGNORE_TOKEN_ID = -100 class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): + """ + Tokenizing strategy for Pygmalion. + """ + bot_prefix_token_ids = [] def __init__(self, prompter, tokenizer, *args, **kwargs): - super().__init__(prompter, tokenizer) + super().__init__(prompter, tokenizer, *args, **kwargs) res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True) self.bot_prefix_token_ids = res["input_ids"] @@ -23,7 +29,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): "labels": [], } current_len = 0 - for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): + for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): role, message = part if role == "system": prefix = "<|system|>" @@ -96,10 +102,16 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): class PygmalionPrompter: + """ + Prompter for Pygmalion. + """ + def __init__(self, *args, **kwargs): pass - def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]: + def build_prompt( + self, source, *args, **kwargs # pylint: disable=unused-argument + ) -> Generator[str, None, None]: for msg in source: yield msg["role"], msg["value"] From 5d86137f70f23ea5c1663191ae260510bdd331db Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:20:11 +0900 Subject: [PATCH 30/59] Lint prompt_tokenizers --- src/axolotl/prompt_tokenizers.py | 111 +++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 22 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index a91a4e2d3..7febd0a72 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,7 +1,10 @@ +"""Module containing PromptTokenizingStrategy and Prompter classes""" + import abc import copy import functools import logging +from typing import Tuple from transformers import PreTrainedTokenizer @@ -15,10 +18,16 @@ LLAMA_DEFAULT_UNK_TOKEN = "" class InvalidDataException(Exception): - pass + """ + Exception raised when the data is invalid + """ class PromptTokenizingStrategy(abc.ABC): + """ + Abstract class for tokenizing strategies + """ + def __init__( self, prompter, @@ -35,14 +44,14 @@ class PromptTokenizingStrategy(abc.ABC): def tokenize_prompt(self, prompt): pass - @functools.cache + @functools.lru_cache(maxsize=128) def _get_user_token(self): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") if isinstance(id_or_ids, (int,)): return id_or_ids return False - @functools.cache + @functools.lru_cache(maxsize=128) def _get_assistant_token(self): id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") if isinstance(id_or_ids, (int,)): @@ -51,11 +60,19 @@ class PromptTokenizingStrategy(abc.ABC): class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for instruction-based prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: raise NotImplementedError def tokenize_prompt(self, prompt): - instruction, input, response = self.parse_instruction_fields(prompt) + ( + instruction, + input, # pylint: disable=redefined-builtin + response, + ) = self.parse_instruction_fields(prompt) full_prompt = self._build_full_prompt(instruction, input, response) tokenized_full_prompt = self._tokenize(full_prompt) if not self.train_on_inputs: @@ -76,7 +93,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, response): + def _build_full_prompt( + self, instruction, input, response # pylint: disable=redefined-builtin + ): return next( iter( self.prompter.build_prompt( @@ -112,7 +131,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Alpaca prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Alpaca Multiple Choice prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], "\n".join(f'- "{choice}"' for choice in prompt["choices"]), @@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for Jeopardy prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["question"], prompt["category"], @@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for OpenAssistant prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["INSTRUCTION"], "", @@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy) class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for SummarizeTLDR prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["article"], "", @@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy) class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for GPTeacher prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str): + """ + Tokenizing strategy for NomicGPT4All prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: return ( prompt["prompt"], "", @@ -175,6 +222,10 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + """ + Tokenizing strategy for Completion prompts. + """ + def parse_instruction_fields(self, prompt) -> str: return prompt["text"] @@ -185,18 +236,24 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, response): + def _build_full_prompt( + self, instruction, input, response + ): # pylint: disable=unused-argument, redefined-builtin return next(iter(self.prompter.build_prompt(instruction))) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): + """ + Tokenizing strategy for Reflection prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: raise NotImplementedError def tokenize_prompt(self, prompt): ( instruction, - input, + input, # pylint: disable=redefined-builtin output, reflection, corrected, @@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): return tokenized_full_prompt - def _build_full_prompt(self, instruction, input, output, reflection, corrected): + def _build_full_prompt( + self, instruction, input, output, reflection, corrected + ): # pylint: disable=redefined-builtin return next( iter( self.prompter.build_prompt( @@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): - def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): + """ + Tokenizing strategy for Alpaca Reflection prompts. + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]: return ( prompt["instruction"], prompt["input"] if "input" in prompt else "", @@ -268,6 +331,10 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): + """ + Tokenizing strategy for ShareGPT prompts. + """ + def get_conversation_thread(self, prompt): return prompt["conversations"] @@ -281,7 +348,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): user_token = self._get_user_token() assistant_token = self._get_assistant_token() try: - for i, part in enumerate( + for _, part in enumerate( self.prompter.build_prompt(self.get_conversation_thread(prompt)) ): if isinstance(part, tuple): @@ -307,7 +374,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): # not masked out from labels labels = copy.deepcopy(res["input_ids"]) else: - logging.warning("unhandled role: " + part[0]) + logging.warning(f"unhandled role: {part[0]}") else: # this is only ever the first part, should include the bos token and the user query res = self._tokenize( @@ -324,8 +391,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): result["labels"][current_len : current_len + input_len] = labels current_len += input_len return result - except (KeyError, AssertionError, IndexError) as e: - raise InvalidDataException(str(e)) + except (KeyError, AssertionError, IndexError) as err: + raise InvalidDataException(str(err)) from err def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): result = self.tokenizer( From 633ff2150fd65e0d19bcb66389eb8f3e8c87ea12 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:21:14 +0900 Subject: [PATCH 31/59] Lint dict --- src/axolotl/utils/dict.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py index e3a0a517d..375baf0ea 100644 --- a/src/axolotl/utils/dict.py +++ b/src/axolotl/utils/dict.py @@ -1,3 +1,5 @@ +"""Module containing the DictDefault class""" + from addict import Dict From dae14e5951ec3cc8f8acd6d8b9248c9a3e3e31af Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:23:23 +0900 Subject: [PATCH 32/59] Ignore too-many-instance-attributes --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index 982bad52d..ed973d285 100644 --- a/.pylintrc +++ b/.pylintrc @@ -11,4 +11,4 @@ generated-members=numpy.*, torch.* [pylint.messages_control] disable=missing-function-docstring, line-too-long, import-error, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, - fixme, import-outside-toplevel, logging-fstring-interpolation, + too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, From fe1f4c4e7d3eb7bc412de35831a0466f56d70db5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:25:15 +0900 Subject: [PATCH 33/59] Lint schedulers --- src/axolotl/utils/schedulers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index b9b7e25be..f9b9e3583 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -1,7 +1,13 @@ +"""Module for custom LRScheduler class""" + from torch.optim.lr_scheduler import LRScheduler class InterpolatingLogScheduler(LRScheduler): + """ + A scheduler that interpolates learning rates in a logarithmic fashion + """ + def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1): """A scheduler that interpolates learning rates in a logarithmic fashion @@ -19,7 +25,9 @@ class InterpolatingLogScheduler(LRScheduler): self.num_steps = num_steps self.min_lr = min_lr self.max_lr = max_lr - self.q = (max_lr / min_lr) ** (1 / (num_steps - 1)) + self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name + 1 / (num_steps - 1) + ) super().__init__(optimizer, last_epoch) def get_lr(self): From e6b57decbd559dce82bcc39817e668fc9bc2e09e Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:26:12 +0900 Subject: [PATCH 34/59] Lint tokenization --- src/axolotl/utils/tokenization.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index f23ca8a92..159dbe15d 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,5 +1,8 @@ -from termcolor import colored +"""Module for tokenization utilities""" + + import logging +from termcolor import colored def check_dataset_labels(dataset, tokenizer): @@ -17,7 +20,7 @@ def check_example_labels(example, tokenizer): # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 colored_tokens = [] - for i, (input_id, label_id, mask) in enumerate( + for _, (input_id, label_id, mask) in enumerate( zip(input_ids, labels, attention_mask) ): decoded_input_token = tokenizer.decode(input_id) From c2dbf2c526689bd6a7839c232e80be42dae1cd98 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:26:43 +0900 Subject: [PATCH 35/59] Lint validation --- src/axolotl/utils/validation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index bc2940d5e..f51640686 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,3 +1,5 @@ +"""Module for validating config files""" + import logging From 9c6750a075ece3b79d3ea2daef9a5b21b33ef643 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:27:08 +0900 Subject: [PATCH 36/59] Lint wandb --- src/axolotl/utils/wandb.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py index 992bb1a5f..90e9c2f73 100644 --- a/src/axolotl/utils/wandb.py +++ b/src/axolotl/utils/wandb.py @@ -1,3 +1,5 @@ +"""Module for wandb utilities""" + import os From 0e952889dcc2336a2cf24ccb42aefc942f1b3055 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:28:38 +0900 Subject: [PATCH 37/59] Lint test_dict --- tests/test_dict.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/test_dict.py b/tests/test_dict.py index 81a528fe4..ad4c73480 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -1,3 +1,6 @@ +"""Module for testing DictDefault class""" + + import unittest import pytest @@ -6,6 +9,10 @@ from axolotl.utils.dict import DictDefault class DictDefaultTest(unittest.TestCase): + """ + Test DictDefault class + """ + def test_dict_default(self): cfg = DictDefault( { @@ -73,7 +80,7 @@ class DictDefaultTest(unittest.TestCase): AttributeError, match=r"'NoneType' object has no attribute 'another_random_key'", ): - cfg.random_key.another_random_key + cfg.random_key.another_random_key = "value" def test_dict_shorthand_assignment(self): """ From 1f3c3f5ea0aaeb86e353993d7c8ef54cd8009f59 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 14:29:19 +0900 Subject: [PATCH 38/59] Lint validation --- tests/test_validation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_validation.py b/tests/test_validation.py index af38eb6af..f92e6c6cd 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,3 +1,5 @@ +"""Module for testing the validation module""" + import unittest import pytest @@ -7,6 +9,10 @@ from axolotl.utils.dict import DictDefault class ValidationTest(unittest.TestCase): + """ + Test the validation module + """ + def test_load_4bit_deprecate(self): cfg = DictDefault( { From 8e46c0fb0ddd1dbeb4f31a542ae18d873563192f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:08:26 +0900 Subject: [PATCH 39/59] Refactor duplicate code between Prompter and Pygmalion --- src/axolotl/prompt_strategies/pygmalion.py | 51 +++------- src/axolotl/prompt_tokenizers.py | 111 +++++++++++++-------- 2 files changed, 86 insertions(+), 76 deletions(-) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 01828a034..4cd9a1685 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -5,7 +5,11 @@ import logging from collections import defaultdict from typing import Generator -from axolotl.prompt_tokenizers import PromptTokenizingStrategy +from axolotl.prompt_tokenizers import ( + PromptTokenizingStrategy, + parse_tokenized_to_result, + tokenize_prompt_default, +) IGNORE_TOKEN_ID = -100 @@ -23,12 +27,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): self.bot_prefix_token_ids = res["input_ids"] def tokenize_prompt(self, prompt): - result = { - "input_ids": [], - "attention_mask": [], - "labels": [], - } - current_len = 0 + result, current_len = tokenize_prompt_default() for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): role, message = part if role == "system": @@ -67,37 +66,15 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): else: logging.warning(f"unknown role in conversation: {role}") res = defaultdict(lambda: []) - input_ids = res["input_ids"] - input_len = len(input_ids) - result["input_ids"][current_len : current_len + input_len] = input_ids - result["attention_mask"][current_len : current_len + input_len] = [ - 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids - ] - result["labels"][current_len : current_len + input_len] = labels - current_len += input_len - return result - def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) - if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.sequence_len - and add_eos_token - ): - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: - result["input_ids"] = result["input_ids"][1:] - result["attention_mask"] = result["attention_mask"][1:] - - result["labels"] = result["input_ids"].copy() + # pylint: disable=duplicate-code + result, current_len = parse_tokenized_to_result( + result, + current_len, + res, + labels, + pad_token_id=self.tokenizer.pad_token_id, + ) return result diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 7febd0a72..ceb65e2ab 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -4,7 +4,7 @@ import abc import copy import functools import logging -from typing import Tuple +from typing import Dict, List, Tuple from transformers import PreTrainedTokenizer @@ -58,6 +58,29 @@ class PromptTokenizingStrategy(abc.ABC): return id_or_ids return False + def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.sequence_len, + padding=False, + return_tensors=None, + ) + if ( + result["input_ids"][-1] != self.tokenizer.eos_token_id + and len(result["input_ids"]) < self.sequence_len + and add_eos_token + ): + result["input_ids"].append(self.tokenizer.eos_token_id) + result["attention_mask"].append(1) + + if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: + result["input_ids"] = result["input_ids"][1:] + result["attention_mask"] = result["attention_mask"][1:] + + result["labels"] = result["input_ids"].copy() + return result + class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): """ @@ -106,29 +129,6 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): ) ) - def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): - result = self.tokenizer( - prompt, - truncation=True, - max_length=self.sequence_len, - padding=False, - return_tensors=None, - ) - if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id - and len(result["input_ids"]) < self.sequence_len - and add_eos_token - ): - result["input_ids"].append(self.tokenizer.eos_token_id) - result["attention_mask"].append(1) - - if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: - result["input_ids"] = result["input_ids"][1:] - result["attention_mask"] = result["attention_mask"][1:] - - result["labels"] = result["input_ids"].copy() - return result - class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): """ @@ -295,7 +295,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): ) ) - def _tokenize(self, prompt, add_eos_token=True): + def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): result = self.tokenizer( prompt, truncation=True, @@ -339,12 +339,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): return prompt["conversations"] def tokenize_prompt(self, prompt): - result = { - "input_ids": [], - "attention_mask": [], - "labels": [], - } - current_len = 0 + result, current_len = tokenize_prompt_default() user_token = self._get_user_token() assistant_token = self._get_assistant_token() try: @@ -382,14 +377,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): ) # everything from this is masked out from the labels labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - input_ids = res["input_ids"] - input_len = len(input_ids) - result["input_ids"][current_len : current_len + input_len] = input_ids - result["attention_mask"][current_len : current_len + input_len] = [ - 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids - ] - result["labels"][current_len : current_len + input_len] = labels - current_len += input_len + + # pylint: disable=duplicate-code + result, current_len = parse_tokenized_to_result( + result, + current_len, + res, + labels, + pad_token_id=self.tokenizer.pad_token_id, + ) return result except (KeyError, AssertionError, IndexError) as err: raise InvalidDataException(str(err)) from err @@ -416,3 +412,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): result["labels"] = result["input_ids"].copy() return result + + +def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: + """ + Returns the default values for the tokenize prompt function + """ + + result = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + current_len = 0 + return result, current_len + + +def parse_tokenized_to_result( + result: Dict[str, List[int]], + current_len: int, + res: Dict[str, List[int]], + labels: list[int], + pad_token_id: int | None = None, +) -> Tuple[Dict[str, List[int]], int]: + """ + Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result + """ + + input_ids = res["input_ids"] + input_len = len(input_ids) + result["input_ids"][current_len : current_len + input_len] = input_ids + result["attention_mask"][current_len : current_len + input_len] = [ + 1 if x != pad_token_id else 0 for x in input_ids + ] + result["labels"][current_len : current_len + input_len] = labels + current_len += input_len + + return result, current_len From 1bf1f59a41cd9f8ec5a8868849733fb334ddd3d7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:24:40 +0900 Subject: [PATCH 40/59] Move black to dev requirements --- requirements-dev.txt | 1 + requirements.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 416634f52..20420596a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1 +1,2 @@ pre-commit +black diff --git a/requirements.txt b/requirements.txt index 27b31a139..20a5feb42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ bitsandbytes>=0.39.0 addict fire PyYAML==6.0 -black datasets accelerate>=0.19.0 sentencepiece From afb31e13a3cd6598835009bae19a6a0a2483d136 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:24:54 +0900 Subject: [PATCH 41/59] Add badge and update contribution section --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index adc3c5812..c970e6c48 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,8 @@

Go ahead and axolotl questions!!

+ pre-commit + PyTest Status @@ -406,3 +408,9 @@ Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new). PRs are **greatly welcome**! + +Please run below to setup env +```bash +pip3 install -r requirements-dev.txt +pre-commit install +``` From b832a0ac62e5b6523712e085268d684ea0f56710 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:30:28 +0900 Subject: [PATCH 42/59] Black formatting --- src/axolotl/utils/data.py | 7 +++++-- src/axolotl/utils/validation.py | 4 +++- tests/test_validation.py | 1 - 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 3164f2ecc..7b718bf56 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -107,7 +107,9 @@ def load_tokenized_prepared_datasets( use_auth_token=use_auth_token, ) else: - ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token) + ds: Dataset = load_dataset( + d.path, streaming=False, use_auth_token=use_auth_token + ) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files @@ -293,7 +295,8 @@ def load_prepare_datasets( f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token + f"{cfg.push_dataset_to_hub}/{ds_hash}", + use_auth_token=use_auth_token, ) dataset = dataset["train"] except Exception: # pylint: disable=broad-except diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index f51640686..c4bc4f952 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -40,7 +40,9 @@ def validate_config(cfg): ) if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True: - raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub") + raise ValueError( + "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" + ) # TODO # MPT 7b diff --git a/tests/test_validation.py b/tests/test_validation.py index f92e6c6cd..210e6eb20 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -117,4 +117,3 @@ class ValidationTest(unittest.TestCase): } ) validate_config(cfg) - From be22551435ed644b6b746058069afc2edbbbf334 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:33:40 +0900 Subject: [PATCH 43/59] Fix unsupported operand type(s) for | --- src/axolotl/prompt_tokenizers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index ceb65e2ab..761441a7e 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -4,7 +4,7 @@ import abc import copy import functools import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union from transformers import PreTrainedTokenizer @@ -433,7 +433,7 @@ def parse_tokenized_to_result( current_len: int, res: Dict[str, List[int]], labels: list[int], - pad_token_id: int | None = None, + pad_token_id: Union[int, None] = None, ) -> Tuple[Dict[str, List[int]], int]: """ Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result From db288e9b13b10502bb4e113070f338a0b4eefad9 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 15:51:32 +0900 Subject: [PATCH 44/59] Set python version --- .pre-commit-config.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ea958b7f6..f51dbc6d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +default_language_version: + python: python3.9 + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 From 0dd35c74af9217c3b89c6bc8a46eeac74855169b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 16:54:19 +0900 Subject: [PATCH 45/59] Ignore unsupported-binary-operation --- tests/test_dict.py | 4 +++- tests/test_validation.py | 14 +++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/test_dict.py b/tests/test_dict.py index ad4c73480..4852707fb 100644 --- a/tests/test_dict.py +++ b/tests/test_dict.py @@ -48,7 +48,9 @@ class DictDefaultTest(unittest.TestCase): } ) - cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"}) + cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation + {"key_a": {"key_b": "value_b"}, "key_f": "value_g"} + ) assert ( cfg.key_a.key_b == "value_b" diff --git a/tests/test_validation.py b/tests/test_validation.py index 210e6eb20..a92666001 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -30,7 +30,7 @@ class ValidationTest(unittest.TestCase): } ) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_8bit": True, } @@ -39,7 +39,7 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "gptq": True, } @@ -48,7 +48,7 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": False, } @@ -57,7 +57,7 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": True, } @@ -73,7 +73,7 @@ class ValidationTest(unittest.TestCase): } ) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_8bit": True, } @@ -82,7 +82,7 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "gptq": True, } @@ -91,7 +91,7 @@ class ValidationTest(unittest.TestCase): with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) - cfg = base_cfg | DictDefault( + cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": True, } From 741a3f2edcc8cb19cad04fffc449e5fd96c86831 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 17:35:51 +0900 Subject: [PATCH 46/59] Add mypy --- .mypy.ini | 3 +++ .pre-commit-config.yaml | 9 +++++++++ requirements-dev.txt | 1 + 3 files changed, 13 insertions(+) create mode 100644 .mypy.ini diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 000000000..486dde3fe --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,3 @@ +[mypy] + +exclude = venv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f51dbc6d3..46e29e7a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,3 +20,12 @@ repos: rev: v2.17.4 hooks: - id: pylint +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.3.0 + hooks: + - id: mypy + additional_dependencies: + [ + 'fire', + 'types-PyYAML' + ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 20420596a..df7e312cb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,3 @@ pre-commit black +mypy From f1232b35ba00d043f316a916f8cbb4bf35714ad8 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 18:04:17 +0900 Subject: [PATCH 47/59] Update mypy dependencies --- .mypy.ini | 30 ++++++++++++++++++++++++++++++ .pre-commit-config.yaml | 3 +-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 486dde3fe..941046ae8 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,3 +1,33 @@ [mypy] exclude = venv + +[mypy-alpaca_lora_4bit.*] +ignore_missing_imports = True + +[mypy-flash_attn.*] +ignore_missing_imports = True + +[mypy-huggingface_hub] +ignore_missing_imports = True + +[mypy-transformers.*] +ignore_missing_imports = True + +[mypy-peft] +ignore_missing_imports = True + +[mypy-bitsandbytes] +ignore_missing_imports = True + +[mypy-datasets] +ignore_missing_imports = True + +[mypy-fire] +ignore_missing_imports = True + +[mypy-setuptools] +ignore_missing_imports = True + +[mypy-addict] +ignore_missing_imports = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46e29e7a0..c578dbc67 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,5 @@ repos: - id: mypy additional_dependencies: [ - 'fire', - 'types-PyYAML' + 'types-PyYAML', ] From e9650d3ae471551acdca4c53f8e920efd3aa5167 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 18:13:39 +0900 Subject: [PATCH 48/59] Fix mypy typing --- scripts/alpaca_json_to_jsonl.py | 3 +- scripts/extract_lora.py | 163 +++++++++++++++++++++ src/axolotl/prompt_strategies/pygmalion.py | 6 +- src/axolotl/prompt_tokenizers.py | 12 +- src/axolotl/prompters.py | 14 +- src/axolotl/utils/data.py | 20 +-- src/axolotl/utils/models.py | 2 +- src/axolotl/utils/trainer.py | 3 +- 8 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 scripts/extract_lora.py diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index f535d1afc..2f56c07b3 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -3,7 +3,7 @@ import os import sys -from typing import Optional +from typing import Optional, Union from pathlib import Path import fire @@ -35,6 +35,7 @@ def main( """ file_reader = FileReader() + writer: Union[StdoutWriter, FileWriter] if to_stdout or output is None: writer = StdoutWriter() else: diff --git a/scripts/extract_lora.py b/scripts/extract_lora.py new file mode 100644 index 000000000..be88c5705 --- /dev/null +++ b/scripts/extract_lora.py @@ -0,0 +1,163 @@ +# import logging +# import os +# import random +# import signal +# import sys +# from pathlib import Path + +# import fire +# import torch +# import yaml +# from addict import Dict + +# from peft import set_peft_model_state_dict, get_peft_model_state_dict + +# # add src to the pythonpath so we don't need to pip install this +# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +# src_dir = os.path.join(project_root, "src") +# sys.path.insert(0, src_dir) + +# from axolotl.utils.data import load_prepare_datasets +# from axolotl.utils.models import load_model +# from axolotl.utils.trainer import setup_trainer +# from axolotl.utils.wandb import setup_wandb_env_vars + +# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) + + +# def choose_device(cfg): +# def get_device(): +# if torch.cuda.is_available(): +# return "cuda" +# else: +# try: +# if torch.backends.mps.is_available(): +# return "mps" +# except: +# return "cpu" + +# cfg.device = get_device() +# if cfg.device == "cuda": +# cfg.device_map = {"": cfg.local_rank} +# else: +# cfg.device_map = {"": cfg.device} + + +# def choose_config(path: Path): +# yaml_files = [file for file in path.glob("*.yml")] + +# if not yaml_files: +# raise ValueError( +# "No YAML config files found in the specified directory. Are you using a .yml extension?" +# ) + +# print("Choose a YAML file:") +# for idx, file in enumerate(yaml_files): +# print(f"{idx + 1}. {file}") + +# chosen_file = None +# while chosen_file is None: +# try: +# choice = int(input("Enter the number of your choice: ")) +# if 1 <= choice <= len(yaml_files): +# chosen_file = yaml_files[choice - 1] +# else: +# print("Invalid choice. Please choose a number from the list.") +# except ValueError: +# print("Invalid input. Please enter a number.") + +# return chosen_file + + +# def save_latest_checkpoint_as_lora( +# config: Path = Path("configs/"), +# prepare_ds_only: bool = False, +# **kwargs, +# ): +# if Path(config).is_dir(): +# config = choose_config(config) + +# # load the config from the yaml file +# with open(config, "r") as f: +# cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader)) +# # if there are any options passed in the cli, if it is something that seems valid from the yaml, +# # then overwrite the value +# cfg_keys = dict(cfg).keys() +# for k in kwargs: +# if k in cfg_keys: +# # handle booleans +# if isinstance(cfg[k], bool): +# cfg[k] = bool(kwargs[k]) +# else: +# cfg[k] = kwargs[k] + +# # setup some derived config / hyperparams +# cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size +# cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) +# cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) +# assert cfg.local_rank == 0, "Run this with only one device!" + +# choose_device(cfg) +# cfg.ddp = False + +# if cfg.device == "mps": +# cfg.load_in_8bit = False +# cfg.tf32 = False +# if cfg.bf16: +# cfg.fp16 = True +# cfg.bf16 = False + +# # Load the model and tokenizer +# logging.info("loading model, tokenizer, and lora_config...") +# model, tokenizer, lora_config = load_model( +# cfg.base_model, +# cfg.base_model_config, +# cfg.model_type, +# cfg.tokenizer_type, +# cfg, +# adapter=cfg.adapter, +# inference=True, +# ) + +# model.config.use_cache = False + +# if torch.__version__ >= "2" and sys.platform != "win32": +# logging.info("Compiling torch model") +# model = torch.compile(model) + +# possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")] +# if len(possible_checkpoints) > 0: +# sorted_paths = sorted( +# possible_checkpoints, key=lambda path: int(path.split("-")[-1]) +# ) +# resume_from_checkpoint = sorted_paths[-1] +# else: +# raise FileNotFoundError("Checkpoints folder not found") + +# pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin") + +# assert os.path.exists(pytorch_bin_path), "Bin not found" + +# logging.info(f"Loading {pytorch_bin_path}") +# adapters_weights = torch.load(pytorch_bin_path, map_location="cpu") + +# # d = get_peft_model_state_dict(model) +# print(model.load_state_dict(adapters_weights)) +# # with open('b.log', "w") as f: +# # f.write(str(d.keys())) +# assert False + +# print((adapters_weights.keys())) +# with open("a.log", "w") as f: +# f.write(str(adapters_weights.keys())) +# assert False + +# logging.info("Setting peft model state dict") +# set_peft_model_state_dict(model, adapters_weights) + +# logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}") +# model.save_pretrained(cfg.output_dir) + + +# if __name__ == "__main__": +# fire.Fire(save_latest_checkpoint_as_lora) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 4cd9a1685..d38bc2beb 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -3,7 +3,7 @@ import copy import logging from collections import defaultdict -from typing import Generator +from typing import Generator, List, Tuple from axolotl.prompt_tokenizers import ( PromptTokenizingStrategy, @@ -19,7 +19,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): Tokenizing strategy for Pygmalion. """ - bot_prefix_token_ids = [] + bot_prefix_token_ids: List[int] = [] def __init__(self, prompter, tokenizer, *args, **kwargs): super().__init__(prompter, tokenizer, *args, **kwargs) @@ -88,7 +88,7 @@ class PygmalionPrompter: def build_prompt( self, source, *args, **kwargs # pylint: disable=unused-argument - ) -> Generator[str, None, None]: + ) -> Generator[Tuple[str, str], None, None]: for msg in source: yield msg["role"], msg["value"] diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 761441a7e..d1655da32 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -226,20 +226,16 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): Tokenizing strategy for Completion prompts. """ - def parse_instruction_fields(self, prompt) -> str: - return prompt["text"] - def tokenize_prompt(self, prompt): - instruction = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt(instruction, None, None) + full_prompt = self._build_full_prompt(prompt["text"], None, None) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt def _build_full_prompt( self, instruction, input, response - ): # pylint: disable=unused-argument, redefined-builtin - return next(iter(self.prompter.build_prompt(instruction))) + ): # pylint: disable=redefined-builtin + return next(iter(self.prompter.build_prompt(instruction, input, response))) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): @@ -419,7 +415,7 @@ def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: Returns the default values for the tokenize prompt function """ - result = { + result: Dict[str, List[int]] = { "input_ids": [], "attention_mask": [], "labels": [], diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index eced1d4a5..97c2e3454 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -3,7 +3,7 @@ import dataclasses import logging from enum import auto, Enum -from typing import List, Union, Generator +from typing import List, Optional, Union, Generator IGNORE_TOKEN_ID = -100 @@ -24,7 +24,7 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" - prompt_style = None + prompt_style: Optional[PromptStyle] = None def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value @@ -231,18 +231,18 @@ class Conversation: offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" - sep2: str = None + sep2: Optional[str] = None def get_prompt(self) -> Generator[str, None, None]: - seps = [self.sep, self.sep2] - preamble = self.system + seps[0] + # seps = [self.sep, self.sep2] + preamble = self.system + self.sep yield preamble for _, (role, message) in enumerate(self.messages): if message: - yield (role + ":", " " + message) + yield role + ":" + " " + message else: logging.warning(f"role with empty message: {role}") - yield (role + ":",) + yield role + ":" def copy(self): return Conversation( diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 7b718bf56..74812f9a0 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -3,7 +3,7 @@ import logging from hashlib import md5 from pathlib import Path -from typing import Tuple, Union +from typing import List, Tuple, Union from datasets import ( load_from_disk, @@ -95,40 +95,36 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): - ds: Dataset = load_dataset( + ds = load_dataset( "json", data_files=d.path, streaming=False, split=None ) elif ds_from_hub: if d.data_files: - ds: Dataset = load_dataset( + ds = load_dataset( d.path, streaming=False, data_files=d.data_files, use_auth_token=use_auth_token, ) else: - ds: Dataset = load_dataset( + ds = load_dataset( d.path, streaming=False, use_auth_token=use_auth_token ) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds: Dataset = load_dataset( - "json", data_files=fp, streaming=False, split=None - ) + ds = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: raise ValueError("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds: DatasetDict = ds.shuffle(seed=42)["train"].shard( + ds = ds.shuffle(seed=42)["train"].shard( num_shards=d.shards, index=0 ) else: - ds: Dataset = ds.shuffle(seed=42).shard( - num_shards=d.shards, index=0 - ) + ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] @@ -232,7 +228,7 @@ def load_tokenized_prepared_datasets( logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.info("tokenizing, merging, and shuffling master dataset") - samples = [] + samples: List[int] = [] for d in datasets: samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=42) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 7a81b8a49..5cdfaab3c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -81,7 +81,7 @@ def load_model( adapter="lora", inference=False, ): - # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] + # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] """ Load a model from a base model and a model type. """ diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 299e39664..45f13e530 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -5,6 +5,7 @@ import math import os import sys from pathlib import Path +from typing import Optional import bitsandbytes as bnb import torch.cuda @@ -28,7 +29,7 @@ class OneCycleLRSchedulerTrainer(Trainer): self.lr_scheduler = None def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None ): optimizer = self.optimizer if optimizer is None else optimizer num_warmup_steps = self.args.get_warmup_steps(num_training_steps) From 96e8378692b0f84c29e7725b550cf30d79f09fe1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 18:14:33 +0900 Subject: [PATCH 49/59] Delete extract_lora.py --- scripts/extract_lora.py | 163 ---------------------------------------- 1 file changed, 163 deletions(-) delete mode 100644 scripts/extract_lora.py diff --git a/scripts/extract_lora.py b/scripts/extract_lora.py deleted file mode 100644 index be88c5705..000000000 --- a/scripts/extract_lora.py +++ /dev/null @@ -1,163 +0,0 @@ -# import logging -# import os -# import random -# import signal -# import sys -# from pathlib import Path - -# import fire -# import torch -# import yaml -# from addict import Dict - -# from peft import set_peft_model_state_dict, get_peft_model_state_dict - -# # add src to the pythonpath so we don't need to pip install this -# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -# src_dir = os.path.join(project_root, "src") -# sys.path.insert(0, src_dir) - -# from axolotl.utils.data import load_prepare_datasets -# from axolotl.utils.models import load_model -# from axolotl.utils.trainer import setup_trainer -# from axolotl.utils.wandb import setup_wandb_env_vars - -# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) - - -# def choose_device(cfg): -# def get_device(): -# if torch.cuda.is_available(): -# return "cuda" -# else: -# try: -# if torch.backends.mps.is_available(): -# return "mps" -# except: -# return "cpu" - -# cfg.device = get_device() -# if cfg.device == "cuda": -# cfg.device_map = {"": cfg.local_rank} -# else: -# cfg.device_map = {"": cfg.device} - - -# def choose_config(path: Path): -# yaml_files = [file for file in path.glob("*.yml")] - -# if not yaml_files: -# raise ValueError( -# "No YAML config files found in the specified directory. Are you using a .yml extension?" -# ) - -# print("Choose a YAML file:") -# for idx, file in enumerate(yaml_files): -# print(f"{idx + 1}. {file}") - -# chosen_file = None -# while chosen_file is None: -# try: -# choice = int(input("Enter the number of your choice: ")) -# if 1 <= choice <= len(yaml_files): -# chosen_file = yaml_files[choice - 1] -# else: -# print("Invalid choice. Please choose a number from the list.") -# except ValueError: -# print("Invalid input. Please enter a number.") - -# return chosen_file - - -# def save_latest_checkpoint_as_lora( -# config: Path = Path("configs/"), -# prepare_ds_only: bool = False, -# **kwargs, -# ): -# if Path(config).is_dir(): -# config = choose_config(config) - -# # load the config from the yaml file -# with open(config, "r") as f: -# cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader)) -# # if there are any options passed in the cli, if it is something that seems valid from the yaml, -# # then overwrite the value -# cfg_keys = dict(cfg).keys() -# for k in kwargs: -# if k in cfg_keys: -# # handle booleans -# if isinstance(cfg[k], bool): -# cfg[k] = bool(kwargs[k]) -# else: -# cfg[k] = kwargs[k] - -# # setup some derived config / hyperparams -# cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size -# cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) -# cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) -# assert cfg.local_rank == 0, "Run this with only one device!" - -# choose_device(cfg) -# cfg.ddp = False - -# if cfg.device == "mps": -# cfg.load_in_8bit = False -# cfg.tf32 = False -# if cfg.bf16: -# cfg.fp16 = True -# cfg.bf16 = False - -# # Load the model and tokenizer -# logging.info("loading model, tokenizer, and lora_config...") -# model, tokenizer, lora_config = load_model( -# cfg.base_model, -# cfg.base_model_config, -# cfg.model_type, -# cfg.tokenizer_type, -# cfg, -# adapter=cfg.adapter, -# inference=True, -# ) - -# model.config.use_cache = False - -# if torch.__version__ >= "2" and sys.platform != "win32": -# logging.info("Compiling torch model") -# model = torch.compile(model) - -# possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")] -# if len(possible_checkpoints) > 0: -# sorted_paths = sorted( -# possible_checkpoints, key=lambda path: int(path.split("-")[-1]) -# ) -# resume_from_checkpoint = sorted_paths[-1] -# else: -# raise FileNotFoundError("Checkpoints folder not found") - -# pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin") - -# assert os.path.exists(pytorch_bin_path), "Bin not found" - -# logging.info(f"Loading {pytorch_bin_path}") -# adapters_weights = torch.load(pytorch_bin_path, map_location="cpu") - -# # d = get_peft_model_state_dict(model) -# print(model.load_state_dict(adapters_weights)) -# # with open('b.log', "w") as f: -# # f.write(str(d.keys())) -# assert False - -# print((adapters_weights.keys())) -# with open("a.log", "w") as f: -# f.write(str(adapters_weights.keys())) -# assert False - -# logging.info("Setting peft model state dict") -# set_peft_model_state_dict(model, adapters_weights) - -# logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}") -# model.save_pretrained(cfg.output_dir) - - -# if __name__ == "__main__": -# fire.Fire(save_latest_checkpoint_as_lora) From 37293dce07a36f31d3d7f7c2a39f238c8c2a29a0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 18:48:58 +0900 Subject: [PATCH 50/59] Apply isort then black --- scripts/alpaca_json_to_jsonl.py | 7 +-- scripts/finetune.py | 20 +++--- setup.py | 2 +- src/axolotl/datasets.py | 8 ++- src/axolotl/flash_attn.py | 42 +++++++++---- src/axolotl/prompt_strategies/alpaca_chat.py | 1 + src/axolotl/prompt_strategies/creative_acr.py | 34 +++++++--- src/axolotl/prompt_tokenizers.py | 13 +++- src/axolotl/prompters.py | 8 ++- src/axolotl/utils/callbacks.py | 7 ++- src/axolotl/utils/data.py | 63 +++++++++++-------- src/axolotl/utils/models.py | 36 ++++------- src/axolotl/utils/tokenization.py | 1 + src/axolotl/utils/trainer.py | 11 +++- tests/test_validation.py | 2 +- 15 files changed, 158 insertions(+), 97 deletions(-) diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index 2f56c07b3..61cb170ec 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -2,23 +2,20 @@ import os import sys - -from typing import Optional, Union from pathlib import Path +from typing import Optional, Union import fire - from axolotl.convert import ( FileReader, - StdoutWriter, FileWriter, JsonlSerializer, JsonParser, JsonToJsonlConverter, + StdoutWriter, ) - # add src to the pythonpath so we don't need to pip install this project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") diff --git a/scripts/finetune.py b/scripts/finetune.py index 226068020..4716744b2 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -7,20 +7,20 @@ import random import signal import sys from pathlib import Path -from typing import Optional, List, Dict, Any, Union +from typing import Any, Dict, List, Optional, Union import fire import torch import yaml +from axolotl.utils.data import load_prepare_datasets +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer + # add src to the pythonpath so we don't need to pip install this from axolotl.utils.tokenization import check_dataset_labels -from axolotl.utils.validation import validate_config -from axolotl.utils.dict import DictDefault - -from axolotl.utils.data import load_prepare_datasets -from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer +from axolotl.utils.validation import validate_config from axolotl.utils.wandb import setup_wandb_env_vars project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -242,7 +242,10 @@ def train( if cfg.local_rank == 0: signal.signal( signal.SIGINT, - lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)), + lambda signal, frame: ( + model.save_pretrained(cfg.output_dir), + sys.exit(0), + ), ) logging.info("Starting trainer...") @@ -255,7 +258,8 @@ def train( ] if len(possible_checkpoints) > 0: sorted_paths = sorted( - possible_checkpoints, key=lambda path: int(path.split("-")[-1]) + possible_checkpoints, + key=lambda path: int(path.split("-")[-1]), ) resume_from_checkpoint = sorted_paths[-1] logging.info( diff --git a/setup.py b/setup.py index 7f51f495f..de9fdc62f 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ """setup.py for axolotl""" -from setuptools import setup, find_packages +from setuptools import find_packages, setup install_requires = [] with open("./requirements.txt", encoding="utf-8") as requirements_file: diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 1e72be114..fb5e15656 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -5,8 +5,8 @@ from typing import List import torch from datasets import IterableDataset -from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException +from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy # We want this to be a wrapper for an existing dataset that we have loaded # lets use the concept of middlewares to wrap each dataset, for example @@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset): logging.warning( f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" ) - buffer = {"input_ids": [], "attention_mask": [], "labels": []} + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } buffer_len = 0 if example: diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py index c7bd12c66..6df0b8e18 100644 --- a/src/axolotl/flash_attn.py +++ b/src/axolotl/flash_attn.py @@ -5,14 +5,11 @@ from typing import Optional, Tuple import torch - import transformers -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb - from einops import rearrange - +from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func -from flash_attn.bert_padding import unpad_input, pad_input +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb def forward( @@ -75,7 +72,11 @@ def forward( qkv = rearrange(qkv, "b s ... -> (b s) ...") max_s = q_len cu_q_lens = torch.arange( - 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + 0, + (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=qkv.device, ) output = flash_attn_unpadded_qkvpacked_func( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True @@ -88,25 +89,44 @@ def forward( x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad = rearrange( - x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads + x_unpad, + "nnz (three h d) -> nnz three h d", + three=3, + h=nheads, ) output_unpad = flash_attn_unpadded_qkvpacked_func( - x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + x_unpad, + cu_q_lens, + max_s, + 0.0, + softmax_scale=None, + causal=True, ) output = rearrange( pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len + rearrange(output_unpad, "nnz h d -> nnz (h d)"), + indices, + bsz, + q_len, ), "b s (h d) -> b s h d", h=nheads, ) - return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None + return ( + self.o_proj(rearrange(output, "b s h d -> b s (h d)")), + None, + None, + ) # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length + self, + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, ): # pylint: disable=unused-argument # [bsz, seq_len] return attention_mask diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 29a0cb654..15dfb65c4 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -1,6 +1,7 @@ """Module containing the AlpacaQAPromptTokenizingStrategy class""" from typing import Tuple + from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy, diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py index 5cf89127d..ea67034b3 100644 --- a/src/axolotl/prompt_strategies/creative_acr.py +++ b/src/axolotl/prompt_strategies/creative_acr.py @@ -1,8 +1,9 @@ """Module loading the CreativePromptTokenizingStrategy and similar classes""" -from typing import Tuple, Union, Generator +from typing import Generator, Tuple, Union import yaml + from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy @@ -61,10 +62,14 @@ Answer: {answer} def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( - prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper + prompt["scores"], + default_flow_style=False, + Dumper=yaml.Dumper, ) critiques = yaml.dump( - prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper + prompt["critiques"], + default_flow_style=False, + Dumper=yaml.Dumper, ) evaluation = scores + critiques question = prompt["instruction"] @@ -97,10 +102,14 @@ Evaluation: def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( - prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper + prompt["scores"], + default_flow_style=False, + Dumper=yaml.Dumper, ) critiques = yaml.dump( - prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper + prompt["critiques"], + default_flow_style=False, + Dumper=yaml.Dumper, ) evaluation = scores + critiques question = prompt["instruction"] @@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase): def load_answer(tokenizer, cfg): return CreativeAnsweringPromptTokenizingStrategy( - CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeAnswerPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) def load_critique(tokenizer, cfg): return CreativeCritiquePromptTokenizingStrategy( - CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeCritiquePrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) def load_revise(tokenizer, cfg): return CreativeRevisePromptTokenizingStrategy( - CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeRevisePrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index d1655da32..3acae91b8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): part = part[0] + part[1] if not user_token else part[1] # this is still the user query, we should res = self._tokenize( - part.strip(), add_eos_token=False, strip_bos_token=True + part.strip(), + add_eos_token=False, + strip_bos_token=True, ) if user_token: res["input_ids"] = [user_token, *res["input_ids"]] @@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): part = part[0] + part[1] if not assistant_token else part[1] # this should be the assistent response, should end with an eos token res = self._tokenize( - part.strip(), add_eos_token=True, strip_bos_token=True + part.strip(), + add_eos_token=True, + strip_bos_token=True, ) if assistant_token: - res["input_ids"] = [assistant_token, *res["input_ids"]] + res["input_ids"] = [ + assistant_token, + *res["input_ids"], + ] # not masked out from labels labels = copy.deepcopy(res["input_ids"]) else: diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 97c2e3454..1a2535e19 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -2,8 +2,8 @@ import dataclasses import logging -from enum import auto, Enum -from typing import List, Optional, Union, Generator +from enum import Enum, auto +from typing import Generator, List, Optional, Union IGNORE_TOKEN_ID = -100 @@ -203,7 +203,9 @@ class ReflectAlpacaPrompter: res = self.prompt_no_input.format(instruction=instruction) if output and reflection and corrected: label = self.agent_label.format( - output=output, reflection=reflection, corrected=corrected + output=output, + reflection=reflection, + corrected=corrected, ) res = f"{res}{label}" yield res diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 70e83d6e4..f6852249a 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -4,9 +4,9 @@ import os from transformers import ( TrainerCallback, - TrainingArguments, - TrainerState, TrainerControl, + TrainerState, + TrainingArguments, ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR @@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public- **kwargs, ): checkpoint_folder = os.path.join( - args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" + args.output_dir, + f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", ) peft_model_path = os.path.join(checkpoint_folder, "adapter_model") diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 74812f9a0..c505cccfa 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -5,38 +5,33 @@ from hashlib import md5 from pathlib import Path from typing import List, Tuple, Union -from datasets import ( - load_from_disk, - load_dataset, - Dataset, - DatasetDict, -) +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase -from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset +from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.prompt_strategies import load from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - GPTeacherPromptTokenizingStrategy, - OpenAssistantPromptTokenizingStrategy, - AlpacaReflectionPTStrategy, - ShareGPTPromptTokenizingStrategy, - JeopardyPromptTokenizingStrategy, - CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy, + AlpacaPromptTokenizingStrategy, + AlpacaReflectionPTStrategy, + CompletionPromptTokenizingStrategy, + GPTeacherPromptTokenizingStrategy, + JeopardyPromptTokenizingStrategy, + OpenAssistantPromptTokenizingStrategy, + ShareGPTPromptTokenizingStrategy, SummarizeTLDRPromptTokenizingStrategy, ) from axolotl.prompters import ( AlpacaPrompter, + CompletionPrompter, GPTeacherPrompter, + JeopardyPrompter, + MultipleChoiceConcisePrompter, + MultipleChoiceExplainPrompter, ReflectAlpacaPrompter, ShareGPTPrompter, - JeopardyPrompter, - CompletionPrompter, - MultipleChoiceExplainPrompter, SummarizeTLDRPrompter, - MultipleChoiceConcisePrompter, ) @@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets( try: if cfg.push_dataset_to_hub: dataset = load_dataset( - f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token + f"{cfg.push_dataset_to_hub}/{ds_hash}", + use_auth_token=use_auth_token, ) dataset = dataset["train"] except Exception: # pylint: disable=broad-except @@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets( ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: - load_dataset(d.path, streaming=True, use_auth_token=use_auth_token) + load_dataset( + d.path, + streaming=True, + use_auth_token=use_auth_token, + ) ds_from_hub = True except FileNotFoundError: pass @@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): ds = load_dataset( - "json", data_files=d.path, streaming=False, split=None + "json", + data_files=d.path, + streaming=False, + split=None, ) elif ds_from_hub: if d.data_files: @@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets( ) else: ds = load_dataset( - d.path, streaming=False, use_auth_token=use_auth_token + d.path, + streaming=False, + use_auth_token=use_auth_token, ) else: fp = hf_hub_download( - repo_id=d.path, repo_type="dataset", filename=d.data_files + repo_id=d.path, + repo_type="dataset", + filename=d.data_files, ) ds = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: @@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets( def load_prepare_datasets( - tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path + tokenizer: PreTrainedTokenizerBase, + cfg, + default_dataset_prepared_path, ) -> Tuple[Dataset, Dataset]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len @@ -353,7 +362,8 @@ def load_prepare_datasets( f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True + f"{cfg.push_dataset_to_hub}/{ds_hash}", + private=True, ) else: dataset = load_tokenized_prepared_datasets( @@ -365,7 +375,8 @@ def load_prepare_datasets( f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" ) dataset = dataset.shard( - num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx + num_shards=cfg.dataset_shard_num, + index=cfg.dataset_shard_idx, ) dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5cdfaab3c..8ce39b8bc 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -5,23 +5,17 @@ import logging import math import os from pathlib import Path -from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401 +from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import ( # noqa: F401 - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - AutoConfig, - BitsAndBytesConfig, -) +from transformers import AutoModelForCausalLM # noqa: F401 +from transformers import PreTrainedModel # noqa: F401 +from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig try: - from transformers import ( - LlamaForCausalLM, - ) + from transformers import LlamaForCausalLM except ImportError: logging.warning( "This version of transformers does not support Llama. Consider upgrading." @@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN if TYPE_CHECKING: from peft import PeftConfig # noqa: F401 - from axolotl.utils.dict import DictDefault # noqa: F401 from transformers import PreTrainedTokenizer # noqa: F401 + from axolotl.utils.dict import DictDefault # noqa: F401 + def load_tokenizer( base_model_config, @@ -56,7 +51,10 @@ def load_tokenizer( logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: + if tokenizer.__class__.__name__ in [ + "LlamaTokenizer", + "LlamaTokenizerFast", + ]: tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": @@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter): def load_llama_adapter(model, cfg): # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - from peft import ( - AdaptionPromptConfig, - get_peft_model, - PeftModel, - ) + from peft import AdaptionPromptConfig, PeftModel, get_peft_model peft_config = AdaptionPromptConfig( adapter_layers=cfg.peft_adapter.layers, # layers (L) @@ -361,11 +355,7 @@ def find_all_linear_names(bits, model): def load_lora(model, cfg): # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - from peft import ( - LoraConfig, - get_peft_model, - PeftModel, - ) + from peft import LoraConfig, PeftModel, get_peft_model lora_target_modules = list(cfg.lora_target_modules or []) diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 159dbe15d..1c535eb1b 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -2,6 +2,7 @@ import logging + from termcolor import colored diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 45f13e530..4e41d1b61 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names -from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.callbacks import SavePeftModelCallback +from axolotl.utils.schedulers import InterpolatingLogScheduler class OneCycleLRSchedulerTrainer(Trainer): @@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer): self.lr_scheduler = None def create_scheduler( - self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None + self, + num_training_steps: int, + optimizer: Optional[torch.optim.Optimizer] = None, ): optimizer = self.optimizer if optimizer is None else optimizer num_warmup_steps = self.args.get_warmup_steps(num_training_steps) @@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): ) callbacks.append(early_stop_cb) - if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0 + if cfg.local_rank == 0 and cfg.adapter in [ + "lora", + "qlora", + ]: # only save in rank 0 callbacks.append(SavePeftModelCallback) data_collator_kwargs = { diff --git a/tests/test_validation.py b/tests/test_validation.py index a92666001..15bc07f84 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -4,8 +4,8 @@ import unittest import pytest -from axolotl.utils.validation import validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.validation import validate_config class ValidationTest(unittest.TestCase): From c17dae6d07d174d226fdbca62d9b44cc1e66bca6 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 21:46:05 +0900 Subject: [PATCH 51/59] Update src/axolotl/prompt_strategies/alpaca_instruct.py Co-authored-by: Wing Lian --- src/axolotl/prompt_strategies/alpaca_instruct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/alpaca_instruct.py b/src/axolotl/prompt_strategies/alpaca_instruct.py index 0d0b267a6..2e42191f8 100644 --- a/src/axolotl/prompt_strategies/alpaca_instruct.py +++ b/src/axolotl/prompt_strategies/alpaca_instruct.py @@ -6,7 +6,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle def load(tokenizer, cfg): return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.INSTRUCT), + AlpacaPrompter(PromptStyle.INSTRUCT.value), tokenizer, cfg.train_on_inputs, cfg.sequence_len, From b1cc54b14a09f4e02b129d2a10ef572795e226d0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 21:49:39 +0900 Subject: [PATCH 52/59] Update pip install to also setup tests --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c970e6c48..66fdf3221 100644 --- a/README.md +++ b/README.md @@ -411,6 +411,9 @@ PRs are **greatly welcome**! Please run below to setup env ```bash -pip3 install -r requirements-dev.txt +pip3 install -r requirements-dev.txt -r requirements-tests.txt pre-commit install + +# test +pytest tests/ ``` From d0114222007f30188f9dbf9e5a828e1ba0774785 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 21:53:29 +0900 Subject: [PATCH 53/59] Add isort --- .isort.cfg | 2 ++ .pre-commit-config.yaml | 4 ++++ 2 files changed, 6 insertions(+) create mode 100644 .isort.cfg diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..b9fb3f3e8 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +profile=black diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c578dbc67..4acdba261 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,10 @@ repos: rev: 23.3.0 hooks: - id: black +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort - repo: https://github.com/PyCQA/flake8 rev: 6.0.0 hooks: From 83d29209f70f573bbde2e6f04c389a51bded89ed Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 22:25:59 +0900 Subject: [PATCH 54/59] Add bandit --- .bandit | 3 +++ .pre-commit-config.yaml | 8 ++++++++ 2 files changed, 11 insertions(+) create mode 100644 .bandit diff --git a/.bandit b/.bandit new file mode 100644 index 000000000..2d81286ae --- /dev/null +++ b/.bandit @@ -0,0 +1,3 @@ +[bandit] +exclude = tests +skips = B101 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4acdba261..b0eb2db49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,11 @@ repos: [ 'types-PyYAML', ] +- repo: https://github.com/PyCQA/bandit + rev: 1.7.5 + hooks: + - id: bandit + args: [ + '--ini', + '.bandit', + ] From a1f9850b91c34cdb5d819351fdcc36f9bbf9221f Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 22:26:26 +0900 Subject: [PATCH 55/59] Fix security issue or ignore false positives --- scripts/finetune.py | 4 ++-- src/axolotl/prompt_tokenizers.py | 8 ++++---- src/axolotl/utils/data.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 4716744b2..6c42b3061 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -136,7 +136,7 @@ def train( # load the config from the yaml file with open(config, encoding="utf-8") as file: - cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader)) + cfg: DictDefault = DictDefault(yaml.safe_load(file)) # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value cfg_keys = cfg.keys() @@ -185,7 +185,7 @@ def train( logging.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( - [random.randrange(0, len(train_dataset) - 1) for i in range(5)] + [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec ), tokenizer, ) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 3acae91b8..582c35ebd 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -11,10 +11,10 @@ from transformers import PreTrainedTokenizer from axolotl.prompters import IGNORE_TOKEN_ID IGNORE_INDEX = -100 -LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" -LLAMA_DEFAULT_EOS_TOKEN = "" -LLAMA_DEFAULT_BOS_TOKEN = "" -LLAMA_DEFAULT_UNK_TOKEN = "" +LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec +LLAMA_DEFAULT_EOS_TOKEN = "" # nosec +LLAMA_DEFAULT_BOS_TOKEN = "" # nosec +LLAMA_DEFAULT_UNK_TOKEN = "" # nosec class InvalidDataException(Exception): diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index c505cccfa..9534323de 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -40,7 +40,7 @@ def load_tokenized_prepared_datasets( ) -> DatasetDict: tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( - md5( + md5( # nosec ( str(cfg.sequence_len) + "@" @@ -66,7 +66,7 @@ def load_tokenized_prepared_datasets( use_auth_token=use_auth_token, ) dataset = dataset["train"] - except Exception: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except # nosec pass if dataset: @@ -272,7 +272,7 @@ def load_prepare_datasets( # see if we can go ahead and load the stacked dataset seed = f"@{str(cfg.seed)}" if cfg.seed else "" ds_hash = str( - md5( + md5( # nosec ( str(cfg.sequence_len) + "@" @@ -304,7 +304,7 @@ def load_prepare_datasets( use_auth_token=use_auth_token, ) dataset = dataset["train"] - except Exception: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except # nosec pass if dataset: From cfcc549f6b5e15ee1dc45e47806978e91bebbe3a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 30 May 2023 10:38:20 -0400 Subject: [PATCH 56/59] fix relative path for fixtures --- src/axolotl/utils/models.py | 5 +++-- src/axolotl/utils/trainer.py | 3 ++- tests/test_prompt_tokenizers.py | 15 +++++++++++++-- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8ce39b8bc..2366854bf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -129,6 +129,7 @@ def load_model( llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch_dtype, + bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) @@ -280,8 +281,8 @@ def load_model( # llama is PROBABLY model parallelizable, but the default isn't that it is # so let's only set it for the 4bit, see # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133 - setattr(model, 'is_parallelizable', True) - setattr(model, 'model_parallel', True) + setattr(model, "is_parallelizable", True) + setattr(model, "model_parallel", True) requires_grad = [] for name, param in model.named_parameters(recurse=True): diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 4e41d1b61..2986c491b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -125,7 +125,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): output_dir=cfg.output_dir, save_total_limit=3, load_best_model_at_end=( - cfg.val_set_size > 0 + cfg.load_best_model_at_end is not False + and cfg.val_set_size > 0 and save_steps and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 7595ffbe4..a8d0cf816 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -1,6 +1,8 @@ +"""Module for testing prompt tokenizers.""" import json import logging import unittest + from pathlib import Path from transformers import AutoTokenizer @@ -12,6 +14,10 @@ logging.basicConfig(level="INFO") class TestPromptTokenizationStrategies(unittest.TestCase): + """ + Test class for prompt tokenization strategies. + """ + def setUp(self) -> None: self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer.add_special_tokens( @@ -24,10 +30,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase): def test_sharegpt_integration(self): print(Path(__file__).parent) - with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin: + with open( + Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" + ) as fin: data = fin.read() conversation = json.loads(data) - with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin: + with open( + Path(__file__).parent / "fixtures/conversation.tokenized.json", + encoding="utf-8", + ) as fin: data = fin.read() tokenized_conversation = json.loads(data) prompter = ShareGPTPrompter("chat") From 25eeeeba0b891cf8be24180cf3373603d84490be Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 31 May 2023 00:38:08 +0900 Subject: [PATCH 57/59] Fix sharegpt prompt --- src/axolotl/prompt_tokenizers.py | 15 ++++++++------- src/axolotl/prompters.py | 10 +++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 582c35ebd..8b3c88fee 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): ] # not masked out from labels labels = copy.deepcopy(res["input_ids"]) + elif part[0] == "SYSTEM:": + part = part[1] # Ignore the system role from preamble + # this is only ever the first part, should include the bos token and the user query + res = self._tokenize( + part.strip(), add_eos_token=False, strip_bos_token=False + ) + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) else: logging.warning(f"unhandled role: {part[0]}") - else: - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - part.strip(), add_eos_token=False, strip_bos_token=False - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) # pylint: disable=duplicate-code result, current_len = parse_tokenized_to_result( diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 1a2535e19..39c74023b 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -3,7 +3,7 @@ import dataclasses import logging from enum import Enum, auto -from typing import Generator, List, Optional, Union +from typing import Generator, List, Optional, Tuple, Union IGNORE_TOKEN_ID = -100 @@ -235,16 +235,16 @@ class Conversation: sep: str = "###" sep2: Optional[str] = None - def get_prompt(self) -> Generator[str, None, None]: + def get_prompt(self) -> Generator[Tuple[str, str], None, None]: # seps = [self.sep, self.sep2] preamble = self.system + self.sep - yield preamble + yield ("SYSTEM:", preamble) for _, (role, message) in enumerate(self.messages): if message: - yield role + ":" + " " + message + yield (role + ":", " " + message) else: logging.warning(f"role with empty message: {role}") - yield role + ":" + yield (role + ":", "") def copy(self): return Conversation( From 594e72b6e8f3c6182c3828cc7f077aed82e07e62 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 31 May 2023 02:58:50 +0900 Subject: [PATCH 58/59] Fix incorrect rebase --- src/axolotl/utils/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 2366854bf..0737d0f12 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -129,7 +129,6 @@ def load_model( llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch_dtype, - bnb_4bit_compute_dtype=torch_dtype, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) From b81c97ff7669fd41872e75b562aa466159b53554 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 31 May 2023 03:01:38 +0900 Subject: [PATCH 59/59] Fix pre-commit for rebased files --- tests/fixtures/conversation.tokenized.json | 2 +- tests/test_prompt_tokenizers.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/fixtures/conversation.tokenized.json b/tests/fixtures/conversation.tokenized.json index 5474624ad..0ac93713b 100644 --- a/tests/fixtures/conversation.tokenized.json +++ b/tests/fixtures/conversation.tokenized.json @@ -1 +1 @@ -{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2]} \ No newline at end of file +{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2]} diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index a8d0cf816..fa85fe5f6 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -2,7 +2,6 @@ import json import logging import unittest - from pathlib import Path from transformers import AutoTokenizer