diff --git a/.bandit b/.bandit
new file mode 100644
index 000000000..2d81286ae
--- /dev/null
+++ b/.bandit
@@ -0,0 +1,3 @@
+[bandit]
+exclude = tests
+skips = B101
diff --git a/.flake8 b/.flake8
new file mode 100644
index 000000000..fd69af775
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 88
+
+select = C,E,F,W,B,B950
+extend-ignore = E203, E501, W503
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 000000000..626edc686
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,16 @@
+name: pre-commit
+
+on:
+ pull_request:
+ push:
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
+ with:
+ python-version: "3.9"
+ cache: 'pip' # caching pip dependencies
+ - uses: pre-commit/action@v3.0.0
diff --git a/.gitignore b/.gitignore
index 93a4f81b5..614a6676b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,4 +160,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
\ No newline at end of file
+.idea/
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 000000000..b9fb3f3e8
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,2 @@
+[settings]
+profile=black
diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 000000000..941046ae8
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,33 @@
+[mypy]
+
+exclude = venv
+
+[mypy-alpaca_lora_4bit.*]
+ignore_missing_imports = True
+
+[mypy-flash_attn.*]
+ignore_missing_imports = True
+
+[mypy-huggingface_hub]
+ignore_missing_imports = True
+
+[mypy-transformers.*]
+ignore_missing_imports = True
+
+[mypy-peft]
+ignore_missing_imports = True
+
+[mypy-bitsandbytes]
+ignore_missing_imports = True
+
+[mypy-datasets]
+ignore_missing_imports = True
+
+[mypy-fire]
+ignore_missing_imports = True
+
+[mypy-setuptools]
+ignore_missing_imports = True
+
+[mypy-addict]
+ignore_missing_imports = True
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..b0eb2db49
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,42 @@
+default_language_version:
+ python: python3.9
+
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: check-yaml
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+- repo: https://github.com/psf/black
+ rev: 23.3.0
+ hooks:
+ - id: black
+- repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+- repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+- repo: https://github.com/PyCQA/pylint
+ rev: v2.17.4
+ hooks:
+ - id: pylint
+- repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.3.0
+ hooks:
+ - id: mypy
+ additional_dependencies:
+ [
+ 'types-PyYAML',
+ ]
+- repo: https://github.com/PyCQA/bandit
+ rev: 1.7.5
+ hooks:
+ - id: bandit
+ args: [
+ '--ini',
+ '.bandit',
+ ]
diff --git a/.pylintrc b/.pylintrc
new file mode 100644
index 000000000..ed973d285
--- /dev/null
+++ b/.pylintrc
@@ -0,0 +1,14 @@
+[MASTER]
+init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
+
+[TYPECHECK]
+
+# List of members which are set dynamically and missed by Pylint inference
+# system, and so shouldn't trigger E1101 when accessed.
+generated-members=numpy.*, torch.*
+
+
+[pylint.messages_control]
+disable=missing-function-docstring, line-too-long, import-error,
+ too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
+ too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
diff --git a/README.md b/README.md
index adc3c5812..66fdf3221 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,8 @@
Go ahead and axolotl questions!!
+
+
@@ -406,3 +408,12 @@ Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **greatly welcome**!
+
+Please run below to setup env
+```bash
+pip3 install -r requirements-dev.txt -r requirements-tests.txt
+pre-commit install
+
+# test
+pytest tests/
+```
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index a61f6d42d..0ce43b621 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic
-
diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml
index 1291198cf..090cc6bcf 100644
--- a/examples/falcon/config-7b-lora.yml
+++ b/examples/falcon/config-7b-lora.yml
@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"
-
diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml
index 787c4121c..dc67d6125 100644
--- a/examples/falcon/config-7b.yml
+++ b/examples/falcon/config-7b.yml
@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>"
-
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 000000000..df7e312cb
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,3 @@
+pre-commit
+black
+mypy
diff --git a/requirements.txt b/requirements.txt
index 27b31a139..20a5feb42 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,6 @@ bitsandbytes>=0.39.0
addict
fire
PyYAML==6.0
-black
datasets
accelerate>=0.19.0
sentencepiece
diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py
index 98c968309..61cb170ec 100644
--- a/scripts/alpaca_json_to_jsonl.py
+++ b/scripts/alpaca_json_to_jsonl.py
@@ -1,24 +1,38 @@
+"""Module to convert json file to jsonl"""
+
import os
import sys
from pathlib import Path
+from typing import Optional, Union
import fire
-from typing import Optional
+
+from axolotl.convert import (
+ FileReader,
+ FileWriter,
+ JsonlSerializer,
+ JsonParser,
+ JsonToJsonlConverter,
+ StdoutWriter,
+)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
-from axolotl.convert import *
-
def main(
- input: Path,
+ file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
+ """
+ Convert a json file to jsonl
+ """
+
file_reader = FileReader()
+ writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
@@ -28,7 +42,7 @@ def main(
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
- converter.convert(input, output)
+ converter.convert(file, output)
if __name__ == "__main__":
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 58f1c0957..6c42b3061 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -1,3 +1,5 @@
+"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+
import importlib
import logging
import os
@@ -5,25 +7,26 @@ import random
import signal
import sys
from pathlib import Path
-from typing import Optional, List, Dict, Any, Union
+from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
+from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
-from axolotl.utils.dict import DictDefault
+from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model, load_tokenizer
-from axolotl.utils.trainer import setup_trainer
-from axolotl.utils.wandb import setup_wandb_env_vars
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -31,14 +34,16 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def choose_device(cfg):
def get_device():
- if torch.cuda.is_available():
- return f"cuda:{cfg.local_rank}"
- else:
- try:
- if torch.backends.mps.is_available():
- return "mps"
- except:
- return "cpu"
+ try:
+ if torch.cuda.is_available():
+ return f"cuda:{cfg.local_rank}"
+
+ if torch.backends.mps.is_available():
+ return "mps"
+
+ raise SystemError("No CUDA/mps device found")
+ except Exception: # pylint: disable=broad-exception-caught
+ return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
@@ -51,7 +56,7 @@ def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
- instruction += line
+ instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
@@ -92,7 +97,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
def choose_config(path: Path):
- yaml_files = [file for file in path.glob("*.yml")]
+ yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
@@ -130,12 +135,12 @@ def train(
config = choose_config(config)
# load the config from the yaml file
- with open(config, "r") as f:
- cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
+ with open(config, encoding="utf-8") as file:
+ cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
- for k in kwargs:
+ for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
# handle booleans
@@ -167,13 +172,11 @@ def train(
# load the tokenizer first
logging.info("loading tokenizer...")
- tokenizer = load_tokenizer(
- cfg.base_model_config,
- cfg.tokenizer_type,
- cfg
- )
+ tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
- if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
+ if check_not_in(
+ ["inference", "shard", "merge_lora"], kwargs
+ ): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -182,7 +185,7 @@ def train(
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
- [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+ [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
@@ -239,7 +242,10 @@ def train(
if cfg.local_rank == 0:
signal.signal(
signal.SIGINT,
- lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
+ lambda signal, frame: (
+ model.save_pretrained(cfg.output_dir),
+ sys.exit(0),
+ ),
)
logging.info("Starting trainer...")
@@ -252,7 +258,8 @@ def train(
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
- possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+ possible_checkpoints,
+ key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(
@@ -266,6 +273,7 @@ def train(
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir)
+
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
diff --git a/setup.py b/setup.py
index 134e4be66..de9fdc62f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,9 @@
-from setuptools import setup, find_packages
+"""setup.py for axolotl"""
+
+from setuptools import find_packages, setup
install_requires = []
-with open("./requirements.txt", "r") as requirements_file:
+with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py
index a953252e9..357e0ec50 100644
--- a/src/axolotl/convert.py
+++ b/src/axolotl/convert.py
@@ -1,47 +1,76 @@
+"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
+
+
import json
import sys
class FileReader:
+ """
+ Reads a file and returns its contents as a string
+ """
+
def read(self, file_path):
- with open(file_path, "r") as file:
+ with open(file_path, encoding="utf-8") as file:
return file.read()
class FileWriter:
+ """
+ Writes a string to a file
+ """
+
def __init__(self, file_path):
self.file_path = file_path
def write(self, content):
- with open(self.file_path, "w") as file:
+ with open(self.file_path, "w", encoding="utf-8") as file:
file.write(content)
class StdoutWriter:
+ """
+ Writes a string to stdout
+ """
+
def write(self, content):
sys.stdout.write(content)
sys.stdout.write("\n")
class JsonParser:
+ """
+ Parses a string as JSON and returns the result
+ """
+
def parse(self, content):
return json.loads(content)
class JsonlSerializer:
+ """
+ Serializes a list of JSON objects into a JSONL string
+ """
+
def serialize(self, data):
lines = [json.dumps(item) for item in data]
return "\n".join(lines)
class JsonToJsonlConverter:
+ """
+ Converts a JSON file to JSONL
+ """
+
def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
self.file_reader = file_reader
self.file_writer = file_writer
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
- def convert(self, input_file_path, output_file_path):
+ def convert(
+ self, input_file_path, output_file_path
+ ): # pylint: disable=unused-argument
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 0e166f6f0..fb5e15656 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -1,10 +1,12 @@
+"""Module containing Dataset functionality"""
+
import logging
from typing import List
import torch
from datasets import IterableDataset
-from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
+from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
@@ -14,7 +16,14 @@ from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
class TokenizedPromptDataset(IterableDataset):
- def __init__(
+ """
+ Iterable dataset that returns tokenized prompts from a stream of text files.
+ Args:
+ prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
+ dataset (dataset.Dataset): Dataset with text files.
+ """
+
+ def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
@@ -42,7 +51,7 @@ class ConstantLengthDataset(IterableDataset):
seq_length (int): Length of token sequences to return.
"""
- def __init__(
+ def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
@@ -82,10 +91,8 @@ class ConstantLengthDataset(IterableDataset):
else:
example_len = 0
- if (
- not example_len
- or buffer_len + int(add_concat_token) + example_len
- > self.seq_length
+ if not example_len or (
+ buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +102,8 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
- if (
- labels.size() == input_ids.size()
- and attention_mask.size() == input_ids.size()
+ if labels.size() == input_ids.size() and (
+ attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
@@ -108,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
logging.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
- buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+ buffer = {
+ "input_ids": [],
+ "attention_mask": [],
+ "labels": [],
+ }
buffer_len = 0
if example:
diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
index c1ceec788..6df0b8e18 100644
--- a/src/axolotl/flash_attn.py
+++ b/src/axolotl/flash_attn.py
@@ -1,17 +1,15 @@
+"""Flash attention monkey patch for llama model"""
+
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
import torch
-from torch import nn
-
import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
@@ -74,7 +72,11 @@ def forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
- 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ 0,
+ (bsz + 1) * q_len,
+ step=q_len,
+ dtype=torch.int32,
+ device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -82,35 +84,56 @@ def forward(
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
+
+ # pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
- x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+ x_unpad,
+ "nnz (three h d) -> nnz three h d",
+ three=3,
+ h=nheads,
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
- x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ x_unpad,
+ cu_q_lens,
+ max_s,
+ 0.0,
+ softmax_scale=None,
+ causal=True,
)
output = rearrange(
pad_input(
- rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+ indices,
+ bsz,
+ q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
- return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+ return (
+ self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
+ None,
+ None,
+ )
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
- self, attention_mask, input_shape, inputs_embeds, past_key_values_length
-):
+ self,
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py
index 803eb970c..2f6af208c 100644
--- a/src/axolotl/prompt_strategies/__init__.py
+++ b/src/axolotl/prompt_strategies/__init__.py
@@ -1,3 +1,5 @@
+"""Module to load prompt strategies."""
+
import importlib
@@ -7,8 +9,8 @@ def load(strategy, tokenizer, cfg):
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
- m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
- fn = getattr(m, load_fn)
- return fn(tokenizer, cfg)
- except:
- pass
+ mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
+ func = getattr(mod, load_fn)
+ return func(tokenizer, cfg)
+ except Exception: # pylint: disable=broad-exception-caught
+ return None
diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 7b6ccea7d..15dfb65c4 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -1,3 +1,7 @@
+"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
+
+from typing import Tuple
+
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
@@ -7,7 +11,7 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.chat.value),
+ AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -15,7 +19,11 @@ def load(tokenizer, cfg):
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for AlpacaQA
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"",
@@ -25,7 +33,7 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
def load_qa(tokenizer, cfg):
return AlpacaQAPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.chat.value),
+ AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
diff --git a/src/axolotl/prompt_strategies/alpaca_instruct.py b/src/axolotl/prompt_strategies/alpaca_instruct.py
index 6bce47ccd..2e42191f8 100644
--- a/src/axolotl/prompt_strategies/alpaca_instruct.py
+++ b/src/axolotl/prompt_strategies/alpaca_instruct.py
@@ -1,10 +1,12 @@
+"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
+
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
return AlpacaPromptTokenizingStrategy(
- AlpacaPrompter(PromptStyle.instruct),
+ AlpacaPrompter(PromptStyle.INSTRUCT.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py
index 58e8b2bee..ea67034b3 100644
--- a/src/axolotl/prompt_strategies/creative_acr.py
+++ b/src/axolotl/prompt_strategies/creative_acr.py
@@ -1,11 +1,18 @@
-from typing import Union, Generator
+"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
+
+from typing import Generator, Tuple, Union
import yaml
+
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for Creative Answering
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
question = prompt["instruction"]
answer = prompt[
"revision"
@@ -18,6 +25,10 @@ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrat
class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+ """
+ Tokenizing strategy for Creative Critique
+ """
+
user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
@@ -49,12 +60,16 @@ Question: {question}
Answer: {answer}
"""
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
- prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+ prompt["scores"],
+ default_flow_style=False,
+ Dumper=yaml.Dumper,
)
critiques = yaml.dump(
- prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+ prompt["critiques"],
+ default_flow_style=False,
+ Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -67,6 +82,10 @@ Answer: {answer}
class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+ """
+ Tokenizing strategy for Creative Revise
+ """
+
user_prompt = """Definitions:
refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
@@ -81,12 +100,16 @@ Evaluation:
{evaluation}
"""
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
- prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+ prompt["scores"],
+ default_flow_style=False,
+ Dumper=yaml.Dumper,
)
critiques = yaml.dump(
- prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+ prompt["critiques"],
+ default_flow_style=False,
+ Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -101,13 +124,19 @@ Evaluation:
class CreativePrompterBase:
+ """
+ Base class for Creative Prompters
+ """
+
system_prompt = ""
prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
def build_prompt(
self,
instruction: str,
- input: Union[None, str] = None,
+ input: Union[ # pylint: disable=redefined-builtin, unused-argument
+ None, str
+ ] = None,
output: Union[None, str] = None,
) -> Generator[str, None, None]:
if self.system_prompt:
@@ -120,30 +149,51 @@ class CreativePrompterBase:
class CreativeAnswerPrompter(CreativePrompterBase):
+ """
+ Prompter for Creative Answering
+ """
+
system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
class CreativeCritiquePrompter(CreativePrompterBase):
+ """
+ Prompter for Creative Critique
+ """
+
system_prompt = ""
class CreativeRevisePrompter(CreativePrompterBase):
+ """
+ Prompter for Creative Revise
+ """
+
system_prompt = ""
def load_answer(tokenizer, cfg):
return CreativeAnsweringPromptTokenizingStrategy(
- CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+ CreativeAnswerPrompter(),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
)
def load_critique(tokenizer, cfg):
return CreativeCritiquePromptTokenizingStrategy(
- CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+ CreativeCritiquePrompter(),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
)
def load_revise(tokenizer, cfg):
return CreativeRevisePromptTokenizingStrategy(
- CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+ CreativeRevisePrompter(),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
)
diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py
index ced15c3cf..d38bc2beb 100644
--- a/src/axolotl/prompt_strategies/pygmalion.py
+++ b/src/axolotl/prompt_strategies/pygmalion.py
@@ -1,29 +1,34 @@
+"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
+
import copy
import logging
from collections import defaultdict
-from typing import Generator
+from typing import Generator, List, Tuple
-from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompt_tokenizers import (
+ PromptTokenizingStrategy,
+ parse_tokenized_to_result,
+ tokenize_prompt_default,
+)
IGNORE_TOKEN_ID = -100
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
- bot_prefix_token_ids = []
+ """
+ Tokenizing strategy for Pygmalion.
+ """
+
+ bot_prefix_token_ids: List[int] = []
def __init__(self, prompter, tokenizer, *args, **kwargs):
- super().__init__(prompter, tokenizer)
+ super().__init__(prompter, tokenizer, *args, **kwargs)
res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
self.bot_prefix_token_ids = res["input_ids"]
def tokenize_prompt(self, prompt):
- result = {
- "input_ids": [],
- "attention_mask": [],
- "labels": [],
- }
- current_len = 0
- for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
+ result, current_len = tokenize_prompt_default()
+ for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
role, message = part
if role == "system":
prefix = "<|system|>"
@@ -61,45 +66,29 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
else:
logging.warning(f"unknown role in conversation: {role}")
res = defaultdict(lambda: [])
- input_ids = res["input_ids"]
- input_len = len(input_ids)
- result["input_ids"][current_len : current_len + input_len] = input_ids
- result["attention_mask"][current_len : current_len + input_len] = [
- 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
- ]
- result["labels"][current_len : current_len + input_len] = labels
- current_len += input_len
- return result
- def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
- result = self.tokenizer(
- prompt,
- truncation=True,
- max_length=self.sequence_len,
- padding=False,
- return_tensors=None,
- )
- if (
- result["input_ids"][-1] != self.tokenizer.eos_token_id
- and len(result["input_ids"]) < self.sequence_len
- and add_eos_token
- ):
- result["input_ids"].append(self.tokenizer.eos_token_id)
- result["attention_mask"].append(1)
-
- if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
- result["input_ids"] = result["input_ids"][1:]
- result["attention_mask"] = result["attention_mask"][1:]
-
- result["labels"] = result["input_ids"].copy()
+ # pylint: disable=duplicate-code
+ result, current_len = parse_tokenized_to_result(
+ result,
+ current_len,
+ res,
+ labels,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
return result
class PygmalionPrompter:
+ """
+ Prompter for Pygmalion.
+ """
+
def __init__(self, *args, **kwargs):
pass
- def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
+ def build_prompt(
+ self, source, *args, **kwargs # pylint: disable=unused-argument
+ ) -> Generator[Tuple[str, str], None, None]:
for msg in source:
yield msg["role"], msg["value"]
diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py
index a91a4e2d3..8b3c88fee 100644
--- a/src/axolotl/prompt_tokenizers.py
+++ b/src/axolotl/prompt_tokenizers.py
@@ -1,24 +1,33 @@
+"""Module containing PromptTokenizingStrategy and Prompter classes"""
+
import abc
import copy
import functools
import logging
+from typing import Dict, List, Tuple, Union
from transformers import PreTrainedTokenizer
from axolotl.prompters import IGNORE_TOKEN_ID
IGNORE_INDEX = -100
-LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
-LLAMA_DEFAULT_EOS_TOKEN = ""
-LLAMA_DEFAULT_BOS_TOKEN = ""
-LLAMA_DEFAULT_UNK_TOKEN = ""
+LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
+LLAMA_DEFAULT_EOS_TOKEN = "" # nosec
+LLAMA_DEFAULT_BOS_TOKEN = "" # nosec
+LLAMA_DEFAULT_UNK_TOKEN = "" # nosec
class InvalidDataException(Exception):
- pass
+ """
+ Exception raised when the data is invalid
+ """
class PromptTokenizingStrategy(abc.ABC):
+ """
+ Abstract class for tokenizing strategies
+ """
+
def __init__(
self,
prompter,
@@ -35,59 +44,21 @@ class PromptTokenizingStrategy(abc.ABC):
def tokenize_prompt(self, prompt):
pass
- @functools.cache
+ @functools.lru_cache(maxsize=128)
def _get_user_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
- @functools.cache
+ @functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
return False
-
-class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
- raise NotImplementedError
-
- def tokenize_prompt(self, prompt):
- instruction, input, response = self.parse_instruction_fields(prompt)
- full_prompt = self._build_full_prompt(instruction, input, response)
- tokenized_full_prompt = self._tokenize(full_prompt)
- if not self.train_on_inputs:
- user_prompt = next(
- iter(
- self.prompter.build_prompt(
- instruction,
- input,
- )
- )
- )
- tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
- # TODO this could be sped up using numpy array slicing
- tokenized_full_prompt["labels"] = [
- -100
- ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
-
- return tokenized_full_prompt
-
- def _build_full_prompt(self, instruction, input, response):
- return next(
- iter(
- self.prompter.build_prompt(
- instruction,
- input,
- response,
- )
- )
- )
-
- def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
+ def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
@@ -111,8 +82,60 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
return result
+class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
+ """
+ Tokenizing strategy for instruction-based prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
+ raise NotImplementedError
+
+ def tokenize_prompt(self, prompt):
+ (
+ instruction,
+ input, # pylint: disable=redefined-builtin
+ response,
+ ) = self.parse_instruction_fields(prompt)
+ full_prompt = self._build_full_prompt(instruction, input, response)
+ tokenized_full_prompt = self._tokenize(full_prompt)
+ if not self.train_on_inputs:
+ user_prompt = next(
+ iter(
+ self.prompter.build_prompt(
+ instruction,
+ input,
+ )
+ )
+ )
+ tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
+ # TODO this could be sped up using numpy array slicing
+ tokenized_full_prompt["labels"] = [
+ -100
+ ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+
+ return tokenized_full_prompt
+
+ def _build_full_prompt(
+ self, instruction, input, response # pylint: disable=redefined-builtin
+ ):
+ return next(
+ iter(
+ self.prompter.build_prompt(
+ instruction,
+ input,
+ response,
+ )
+ )
+ )
+
+
class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for Alpaca prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for Alpaca Multiple Choice prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
@@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for Jeopardy prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["question"],
prompt["category"],
@@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for OpenAssistant prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["INSTRUCTION"],
"",
@@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for SummarizeTLDR prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["article"],
"",
@@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for GPTeacher prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str):
+ """
+ Tokenizing strategy for NomicGPT4All prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["prompt"],
"",
@@ -175,28 +222,34 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> str:
- return prompt["text"]
+ """
+ Tokenizing strategy for Completion prompts.
+ """
def tokenize_prompt(self, prompt):
- instruction = self.parse_instruction_fields(prompt)
- full_prompt = self._build_full_prompt(instruction, None, None)
+ full_prompt = self._build_full_prompt(prompt["text"], None, None)
tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt
- def _build_full_prompt(self, instruction, input, response):
- return next(iter(self.prompter.build_prompt(instruction)))
+ def _build_full_prompt(
+ self, instruction, input, response
+ ): # pylint: disable=redefined-builtin
+ return next(iter(self.prompter.build_prompt(instruction, input, response)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+ """
+ Tokenizing strategy for Reflection prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
raise NotImplementedError
def tokenize_prompt(self, prompt):
(
instruction,
- input,
+ input, # pylint: disable=redefined-builtin
output,
reflection,
corrected,
@@ -223,7 +276,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
return tokenized_full_prompt
- def _build_full_prompt(self, instruction, input, output, reflection, corrected):
+ def _build_full_prompt(
+ self, instruction, input, output, reflection, corrected
+ ): # pylint: disable=redefined-builtin
return next(
iter(
self.prompter.build_prompt(
@@ -236,7 +291,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
)
)
- def _tokenize(self, prompt, add_eos_token=True):
+ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
prompt,
truncation=True,
@@ -257,7 +312,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
- def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+ """
+ Tokenizing strategy for Alpaca Reflection prompts.
+ """
+
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
@@ -268,20 +327,19 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
+ """
+ Tokenizing strategy for ShareGPT prompts.
+ """
+
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
- result = {
- "input_ids": [],
- "attention_mask": [],
- "labels": [],
- }
- current_len = 0
+ result, current_len = tokenize_prompt_default()
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
try:
- for i, part in enumerate(
+ for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
@@ -289,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
res = self._tokenize(
- part.strip(), add_eos_token=False, strip_bos_token=True
+ part.strip(),
+ add_eos_token=False,
+ strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
@@ -300,32 +360,39 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token
res = self._tokenize(
- part.strip(), add_eos_token=True, strip_bos_token=True
+ part.strip(),
+ add_eos_token=True,
+ strip_bos_token=True,
)
if assistant_token:
- res["input_ids"] = [assistant_token, *res["input_ids"]]
+ res["input_ids"] = [
+ assistant_token,
+ *res["input_ids"],
+ ]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
+ elif part[0] == "SYSTEM:":
+ part = part[1] # Ignore the system role from preamble
+ # this is only ever the first part, should include the bos token and the user query
+ res = self._tokenize(
+ part.strip(), add_eos_token=False, strip_bos_token=False
+ )
+ # everything from this is masked out from the labels
+ labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
- logging.warning("unhandled role: " + part[0])
- else:
- # this is only ever the first part, should include the bos token and the user query
- res = self._tokenize(
- part.strip(), add_eos_token=False, strip_bos_token=False
- )
- # everything from this is masked out from the labels
- labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
- input_ids = res["input_ids"]
- input_len = len(input_ids)
- result["input_ids"][current_len : current_len + input_len] = input_ids
- result["attention_mask"][current_len : current_len + input_len] = [
- 1 if x != self.tokenizer.pad_token_id else 0 for x in input_ids
- ]
- result["labels"][current_len : current_len + input_len] = labels
- current_len += input_len
+ logging.warning(f"unhandled role: {part[0]}")
+
+ # pylint: disable=duplicate-code
+ result, current_len = parse_tokenized_to_result(
+ result,
+ current_len,
+ res,
+ labels,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
return result
- except (KeyError, AssertionError, IndexError) as e:
- raise InvalidDataException(str(e))
+ except (KeyError, AssertionError, IndexError) as err:
+ raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
result = self.tokenizer(
@@ -349,3 +416,40 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
result["labels"] = result["input_ids"].copy()
return result
+
+
+def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
+ """
+ Returns the default values for the tokenize prompt function
+ """
+
+ result: Dict[str, List[int]] = {
+ "input_ids": [],
+ "attention_mask": [],
+ "labels": [],
+ }
+ current_len = 0
+ return result, current_len
+
+
+def parse_tokenized_to_result(
+ result: Dict[str, List[int]],
+ current_len: int,
+ res: Dict[str, List[int]],
+ labels: list[int],
+ pad_token_id: Union[int, None] = None,
+) -> Tuple[Dict[str, List[int]], int]:
+ """
+ Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
+ """
+
+ input_ids = res["input_ids"]
+ input_len = len(input_ids)
+ result["input_ids"][current_len : current_len + input_len] = input_ids
+ result["attention_mask"][current_len : current_len + input_len] = [
+ 1 if x != pad_token_id else 0 for x in input_ids
+ ]
+ result["labels"][current_len : current_len + input_len] = labels
+ current_len += input_len
+
+ return result, current_len
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 760c714d6..39c74023b 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -1,28 +1,37 @@
-import copy
+"""Module containing prompters"""
+
import dataclasses
import logging
-from enum import auto, Enum
-from typing import List, Tuple, Any, Union, Generator
+from enum import Enum, auto
+from typing import Generator, List, Optional, Tuple, Union
IGNORE_TOKEN_ID = -100
class PromptStyle(Enum):
- instruct = "instruct"
- chat = "chat"
+ """
+ Enum for prompt styles
+ """
+
+ INSTRUCT = "instruct"
+ CHAT = "chat"
class AlpacaPrompter:
+ """
+ Base class for alpaca prompters
+ """
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
- prompt_style = None
+ prompt_style: Optional[PromptStyle] = None
- def __init__(self, prompt_style=PromptStyle.instruct.value):
- self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
+ def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
+ self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()
def match_prompt_style(self):
- if self.prompt_style == PromptStyle.instruct.value:
+ if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
@@ -32,7 +41,7 @@ class AlpacaPrompter:
+ "### Instruction:\n{instruction}\n\n### Response:\n"
)
self.response_split = "### Response:"
- if self.prompt_style == PromptStyle.chat.value:
+ if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
@@ -44,7 +53,7 @@ class AlpacaPrompter:
def build_prompt(
self,
instruction: str,
- input: Union[None, str] = None,
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
@@ -62,33 +71,60 @@ class AlpacaPrompter:
class UnpromptedPrompter(AlpacaPrompter):
+ """
+ Prompter for alpaca no system prompt
+ """
+
system_prompt = ""
system_no_input_prompt = ""
class JeopardyPrompter(AlpacaPrompter):
+ """
+ Prompter for Jeopardy
+ """
+
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
class MultipleChoiceExplainPrompter(AlpacaPrompter):
+ """
+ Prompter for multiple choice explain
+ """
+
system_prompt = (
"Choose the answer that best answers the question. Explain your reasoning."
)
class MultipleChoiceConcisePrompter(AlpacaPrompter):
+ """
+ Prompter for multiple choice concise
+ """
+
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
class SummarizeTLDRPrompter(AlpacaPrompter):
+ """
+ Prompter for summarize TLDR
+ """
+
prompt_no_input = (
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
)
class CompletionPrompter:
+ """
+ Prompter for completion
+ """
+
def build_prompt(
- self, instruction: str, input=None, output=None
+ self,
+ instruction: str,
+ input=None, # pylint: disable=redefined-builtin, unused-argument
+ output=None, # pylint: disable=unused-argument
) -> Generator[str, None, None]:
yield instruction
@@ -97,14 +133,22 @@ class CompletionPrompter:
class GPTeacherPrompter(AlpacaPrompter):
- ...
+ """
+ Prompter for GPTeacher
+ """
class NomicGPT4AllPrompter(AlpacaPrompter):
- ...
+ """
+ Prompter for NomicGPT4All
+ """
class ReflectAlpacaPrompter:
+ """
+ Prompter for ReflectAlpaca
+ """
+
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n"
@@ -120,7 +164,7 @@ class ReflectAlpacaPrompter:
self.match_prompt_style()
def match_prompt_style(self):
- if self.prompt_style == PromptStyle.instruct.value:
+ if self.prompt_style == PromptStyle.INSTRUCT.value:
self.prompt_input = (
self.system_prompt
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
@@ -131,7 +175,7 @@ class ReflectAlpacaPrompter:
)
self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
self.response_split = "### Final Response:"
- if self.prompt_style == PromptStyle.chat.value:
+ if self.prompt_style == PromptStyle.CHAT.value:
self.prompt_input = (
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
)
@@ -146,7 +190,7 @@ class ReflectAlpacaPrompter:
def build_prompt(
self,
instruction: str,
- input: Union[None, str] = None,
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
@@ -159,7 +203,9 @@ class ReflectAlpacaPrompter:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
label = self.agent_label.format(
- output=output, reflection=reflection, corrected=corrected
+ output=output,
+ reflection=reflection,
+ corrected=corrected,
)
res = f"{res}{label}"
yield res
@@ -187,18 +233,18 @@ class Conversation:
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
- sep2: str = None
+ sep2: Optional[str] = None
- def get_prompt(self) -> Generator[str, None, None]:
- seps = [self.sep, self.sep2]
- preamble = self.system + seps[0]
- yield preamble
- for i, (role, message) in enumerate(self.messages):
+ def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
+ # seps = [self.sep, self.sep2]
+ preamble = self.system + self.sep
+ yield ("SYSTEM:", preamble)
+ for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
else:
- logging.warning("role with empty message: " + role)
- yield (role + ":",)
+ logging.warning(f"role with empty message: {role}")
+ yield (role + ":", "")
def copy(self):
return Conversation(
@@ -227,10 +273,14 @@ conv_vicuna_v1_1 = Conversation(
)
-class ShareGPTPrompter:
+class ShareGPTPrompter: # pylint: disable=too-few-public-methods
+ """
+ A prompter that generates prompts for the ShareGPT
+ """
+
def __init__(self, prompt_style=None):
- if prompt_style != PromptStyle.chat.value:
- raise Exception(
+ if prompt_style != PromptStyle.CHAT.value:
+ raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
@@ -240,7 +290,7 @@ class ShareGPTPrompter:
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
# self.response_split = "ASSISTANT:"
- def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
+ def build_prompt(self, source) -> Generator[str, None, None]:
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)
@@ -261,9 +311,9 @@ class ShareGPTPrompter:
):
# Skip the first one if it is not from human
source = source[1:]
- except IndexError as e:
+ except IndexError as err:
# sometimes there is a bing or system chat
- raise e
+ raise err
conv.messages = []
for j, sentence in enumerate(source):
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 229cd9b98..f6852249a 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -1,16 +1,19 @@
+"""Callbacks for Trainer class"""
+
import os
from transformers import (
- Seq2SeqTrainer,
TrainerCallback,
- TrainingArguments,
- TrainerState,
TrainerControl,
+ TrainerState,
+ TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
-class SavePeftModelCallback(TrainerCallback):
+class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
+ """Callback to save the PEFT adapter"""
+
def on_save(
self,
args: TrainingArguments,
@@ -19,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback):
**kwargs,
):
checkpoint_folder = os.path.join(
- args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+ args.output_dir,
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index a0cff21c4..9534323de 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,42 +1,37 @@
+"""Module containing data utilities"""
+
import logging
from hashlib import md5
from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
-from datasets import (
- load_from_disk,
- load_dataset,
- IterableDataset,
- Dataset,
- concatenate_datasets,
- DatasetDict,
-)
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
- AlpacaPromptTokenizingStrategy,
- GPTeacherPromptTokenizingStrategy,
- OpenAssistantPromptTokenizingStrategy,
- AlpacaReflectionPTStrategy,
- ShareGPTPromptTokenizingStrategy,
- JeopardyPromptTokenizingStrategy,
- CompletionPromptTokenizingStrategy,
AlpacaMultipleChoicePromptTokenizingStrategy,
+ AlpacaPromptTokenizingStrategy,
+ AlpacaReflectionPTStrategy,
+ CompletionPromptTokenizingStrategy,
+ GPTeacherPromptTokenizingStrategy,
+ JeopardyPromptTokenizingStrategy,
+ OpenAssistantPromptTokenizingStrategy,
+ ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
+ CompletionPrompter,
GPTeacherPrompter,
+ JeopardyPrompter,
+ MultipleChoiceConcisePrompter,
+ MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter,
- JeopardyPrompter,
- CompletionPrompter,
- MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter,
- MultipleChoiceConcisePrompter,
)
@@ -45,11 +40,13 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
- md5(
+ md5( # nosec
(
str(cfg.sequence_len)
+ "@"
- + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ + "|".join(
+ sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+ )
+ "|"
+ tokenizer_name
).encode("utf-8")
@@ -65,10 +62,11 @@ def load_tokenized_prepared_datasets(
try:
if cfg.push_dataset_to_hub:
dataset = load_dataset(
- f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+ f"{cfg.push_dataset_to_hub}/{ds_hash}",
+ use_auth_token=use_auth_token,
)
dataset = dataset["train"]
- except:
+ except Exception: # pylint: disable=broad-except # nosec
pass
if dataset:
@@ -81,43 +79,59 @@ def load_tokenized_prepared_datasets(
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
logging.info("Loading raw datasets...")
datasets = []
+ # pylint: disable=invalid-name
for d in cfg.datasets:
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
- load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
+ load_dataset(
+ d.path,
+ streaming=True,
+ use_auth_token=use_auth_token,
+ )
ds_from_hub = True
except FileNotFoundError:
pass
# prefer local dataset, even if hub exists
if Path(d.path).exists():
- ds: Dataset = load_dataset(
- "json", data_files=d.path, streaming=False, split=None
+ ds = load_dataset(
+ "json",
+ data_files=d.path,
+ streaming=False,
+ split=None,
)
elif ds_from_hub:
if d.data_files:
- ds: Dataset = load_dataset(
+ ds = load_dataset(
d.path,
streaming=False,
data_files=d.data_files,
use_auth_token=use_auth_token,
)
else:
- ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token)
+ ds = load_dataset(
+ d.path,
+ streaming=False,
+ use_auth_token=use_auth_token,
+ )
else:
fp = hf_hub_download(
- repo_id=d.path, repo_type="dataset", filename=d.data_files
+ repo_id=d.path,
+ repo_type="dataset",
+ filename=d.data_files,
)
- ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
+ ds = load_dataset("json", data_files=fp, streaming=False, split=None)
if not ds:
- raise Exception("unhandled dataset load")
+ raise ValueError("unhandled dataset load")
# support for using a subset of the data
if d.shards:
if "train" in ds:
- ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+ ds = ds.shuffle(seed=42)["train"].shard(
+ num_shards=d.shards, index=0
+ )
else:
- ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+ ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
@@ -221,9 +235,9 @@ def load_tokenized_prepared_datasets(
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset")
- samples = []
+ samples: List[int] = []
for d in datasets:
- samples = samples + [i for i in d]
+ samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=42)
if cfg.local_rank == 0:
logging.info(
@@ -242,8 +256,10 @@ def load_tokenized_prepared_datasets(
def load_prepare_datasets(
- tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
-) -> (Dataset, Dataset):
+ tokenizer: PreTrainedTokenizerBase,
+ cfg,
+ default_dataset_prepared_path,
+) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
@@ -256,13 +272,15 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
- md5(
+ md5( # nosec
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ seed
- + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ + "|".join(
+ sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+ )
+ "|"
+ tokenizer_name
).encode("utf-8")
@@ -282,10 +300,11 @@ def load_prepare_datasets(
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset = load_dataset(
- f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+ f"{cfg.push_dataset_to_hub}/{ds_hash}",
+ use_auth_token=use_auth_token,
)
dataset = dataset["train"]
- except:
+ except Exception: # pylint: disable=broad-except # nosec
pass
if dataset:
@@ -319,7 +338,7 @@ def load_prepare_datasets(
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
- dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+ dataset = Dataset.from_list(list(constant_len_dataset))
# filter out bad data
dataset = Dataset.from_list(
@@ -343,7 +362,8 @@ def load_prepare_datasets(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
- f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+ f"{cfg.push_dataset_to_hub}/{ds_hash}",
+ private=True,
)
else:
dataset = load_tokenized_prepared_datasets(
@@ -355,7 +375,8 @@ def load_prepare_datasets(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
)
dataset = dataset.shard(
- num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+ num_shards=cfg.dataset_shard_num,
+ index=cfg.dataset_shard_idx,
)
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py
index e3a0a517d..375baf0ea 100644
--- a/src/axolotl/utils/dict.py
+++ b/src/axolotl/utils/dict.py
@@ -1,3 +1,5 @@
+"""Module containing the DictDefault class"""
+
from addict import Dict
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 07872a16e..0737d0f12 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,26 +1,22 @@
+"""Module for models and model loading"""
+
+
import logging
import math
import os
from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
-from transformers import (
- AutoModelForCausalLM,
- AutoTokenizer,
- PreTrainedModel,
- AutoConfig,
- BitsAndBytesConfig,
-)
+from transformers import AutoModelForCausalLM # noqa: F401
+from transformers import PreTrainedModel # noqa: F401
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
try:
- from transformers import (
- LlamaForCausalLM,
- LlamaTokenizer,
- )
-except:
+ from transformers import LlamaForCausalLM
+except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
@@ -28,9 +24,10 @@ except:
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING:
- from peft import PeftModel, PeftConfig
- from axolotl.utils.dict import DictDefault
- from transformers import PreTrainedTokenizer
+ from peft import PeftConfig # noqa: F401
+ from transformers import PreTrainedTokenizer # noqa: F401
+
+ from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(
@@ -54,7 +51,10 @@ def load_tokenizer(
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
- if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+ if tokenizer.__class__.__name__ in [
+ "LlamaTokenizer",
+ "LlamaTokenizerFast",
+ ]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -62,8 +62,8 @@ def load_tokenizer(
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if cfg.special_tokens:
- for k, v in cfg.special_tokens.items():
- tokenizer.add_special_tokens({k: v})
+ for k, val in cfg.special_tokens.items():
+ tokenizer.add_special_tokens({k: val})
if cfg.tokens:
tokenizer.add_tokens(list(cfg.tokens))
@@ -79,7 +79,10 @@ def load_model(
adapter="lora",
inference=False,
):
- # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+ # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+ """
+ Load a model from a base model and a model type.
+ """
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
@@ -115,9 +118,9 @@ def load_model(
replace_peft_model_with_int4_lora_model()
from peft import prepare_model_for_int8_training
- except Exception as e:
- logging.exception(e)
- raise e
+ except Exception as err:
+ logging.exception(err)
+ raise err
model_kwargs = {}
if cfg.adapter == "qlora" and cfg.load_in_4bit:
@@ -155,7 +158,7 @@ def load_model(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
- except:
+ except Exception: # pylint: disable=broad-exception-caught
model_path = cfg.base_model
model, _ = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model,
@@ -210,13 +213,13 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
- trust_remote_code=True if cfg.trust_remote_code is True else False,
+ trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
- trust_remote_code=True if cfg.trust_remote_code is True else False,
+ trust_remote_code=cfg.trust_remote_code or False,
)
model = AutoModelForCausalLM.from_pretrained(
base_model,
@@ -225,30 +228,29 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
- trust_remote_code=True if cfg.trust_remote_code is True else False,
+ trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
- except Exception as e:
+ except Exception as err: # pylint: disable=broad-exception-caught
logging.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
- logging.exception(e)
+ logging.exception(err)
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
- trust_remote_code=True if cfg.trust_remote_code is True else False,
+ trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
model.resize_token_embeddings(embeddings_len)
- if (
- ((cfg.adapter == "lora" and load_in_8bit) or cfg.adapter == "qlora")
- and not cfg.gptq
- and (load_in_8bit or cfg.load_in_4bit)
+ if not cfg.gptq and (
+ (cfg.adapter == "lora" and load_in_8bit)
+ or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)
@@ -261,14 +263,14 @@ def load_model(
if cfg.gptq:
# Scales to half
logging.info("Fitting 4bit scales and zeros to half")
- for n, m in model.named_modules():
- if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
- type(m)
+ for _, module in model.named_modules():
+ if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
+ type(module)
):
- if hasattr(m, "is_v1_model") and m.is_v1_model:
- m.zeros = m.zeros.half()
- m.scales = m.scales.half()
- m.bias = m.bias.half()
+ if hasattr(module, "is_v1_model") and module.is_v1_model:
+ module.zeros = module.zeros.half()
+ module.scales = module.scales.half()
+ module.bias = module.bias.half()
if (
torch.cuda.device_count() > 1
@@ -278,8 +280,8 @@ def load_model(
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
- setattr(model, 'is_parallelizable', True)
- setattr(model, 'model_parallel', True)
+ setattr(model, "is_parallelizable", True)
+ setattr(model, "model_parallel", True)
requires_grad = []
for name, param in model.named_parameters(recurse=True):
@@ -308,11 +310,7 @@ def load_adapter(model, cfg, adapter):
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
- from peft import (
- AdaptionPromptConfig,
- get_peft_model,
- PeftModel,
- )
+ from peft import AdaptionPromptConfig, PeftModel, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -357,11 +355,7 @@ def find_all_linear_names(bits, model):
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
- from peft import (
- LoraConfig,
- get_peft_model,
- PeftModel,
- )
+ from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index b9b7e25be..f9b9e3583 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -1,7 +1,13 @@
+"""Module for custom LRScheduler class"""
+
from torch.optim.lr_scheduler import LRScheduler
class InterpolatingLogScheduler(LRScheduler):
+ """
+ A scheduler that interpolates learning rates in a logarithmic fashion
+ """
+
def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
"""A scheduler that interpolates learning rates in a logarithmic fashion
@@ -19,7 +25,9 @@ class InterpolatingLogScheduler(LRScheduler):
self.num_steps = num_steps
self.min_lr = min_lr
self.max_lr = max_lr
- self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
+ self.q = (max_lr / min_lr) ** ( # pylint: disable=invalid-name
+ 1 / (num_steps - 1)
+ )
super().__init__(optimizer, last_epoch)
def get_lr(self):
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index f23ca8a92..1c535eb1b 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -1,6 +1,10 @@
-from termcolor import colored
+"""Module for tokenization utilities"""
+
+
import logging
+from termcolor import colored
+
def check_dataset_labels(dataset, tokenizer):
# the dataset is already shuffled, so let's just check the first 5 elements
@@ -17,7 +21,7 @@ def check_example_labels(example, tokenizer):
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
colored_tokens = []
- for i, (input_id, label_id, mask) in enumerate(
+ for _, (input_id, label_id, mask) in enumerate(
zip(input_ids, labels, attention_mask)
):
decoded_input_token = tokenizer.decode(input_id)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 97b02baba..2986c491b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,8 +1,11 @@
+"""Module containing the Trainer class and related functions"""
+
import importlib
import math
import os
import sys
from pathlib import Path
+from typing import Optional
import bitsandbytes as bnb
import torch.cuda
@@ -12,17 +15,26 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
-from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.schedulers import InterpolatingLogScheduler
class OneCycleLRSchedulerTrainer(Trainer):
+ """
+ Trainer subclass that uses the OneCycleLR scheduler
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lr_scheduler = None
+
def create_scheduler(
- self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+ self,
+ num_training_steps: int,
+ optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
- num_training_steps = num_training_steps
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
@@ -58,11 +70,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
- training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
+ training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
- if cfg.gradient_checkpointing is not None:
+ if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
@@ -112,13 +124,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
save_steps=save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
- load_best_model_at_end=True
- if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
- and cfg.val_set_size > 0
- and save_steps is not None
- and save_steps % eval_steps == 0
- and cfg.load_in_8bit is not True
- else False,
+ load_best_model_at_end=(
+ cfg.load_best_model_at_end is not False
+ and cfg.val_set_size > 0
+ and save_steps
+ and save_steps % eval_steps == 0
+ and cfg.load_in_8bit is not True
+ )
+ or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +153,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
- and not "deepspeed" in training_arguments_kwargs
+ and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
@@ -206,7 +219,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
callbacks.append(early_stop_cb)
- if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
+ if cfg.local_rank == 0 and cfg.adapter in [
+ "lora",
+ "qlora",
+ ]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
data_collator_kwargs = {
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index bc2940d5e..c4bc4f952 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -1,3 +1,5 @@
+"""Module for validating config files"""
+
import logging
@@ -38,7 +40,9 @@ def validate_config(cfg):
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
- raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
+ raise ValueError(
+ "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
+ )
# TODO
# MPT 7b
diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py
index 992bb1a5f..90e9c2f73 100644
--- a/src/axolotl/utils/wandb.py
+++ b/src/axolotl/utils/wandb.py
@@ -1,3 +1,5 @@
+"""Module for wandb utilities"""
+
import os
diff --git a/tests/fixtures/conversation.tokenized.json b/tests/fixtures/conversation.tokenized.json
index 5474624ad..0ac93713b 100644
--- a/tests/fixtures/conversation.tokenized.json
+++ b/tests/fixtures/conversation.tokenized.json
@@ -1 +1 @@
-{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2]}
\ No newline at end of file
+{"input_ids": [1, 319, 13563, 1546, 263, 12758, 1404, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, 6089, 304, 278, 1404, 29915, 29879, 5155, 29889, 3148, 1001, 29901, 920, 1033, 474, 2334, 263, 29086, 705, 11356, 5687, 393, 3667, 4637, 21531, 20159, 304, 4505, 1045, 3163, 29973, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 1033, 366, 2367, 592, 278, 330, 391, 310, 920, 372, 1033, 2466, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, 3148, 1001, 29901, 2367, 592, 263, 2702, 1342, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 10241, 474, 471, 4856, 411, 263, 6483, 8004, 310, 1716, 29892, 1033, 366, 5649, 278, 1021, 6964, 304, 592, 411, 393, 11833, 29973, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 723, 474, 437, 372, 411, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 29973, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, 3148, 1001, 29901, 920, 1033, 474, 2334, 445, 297, 3017, 29973, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2], "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 739, 29915, 29879, 1950, 304, 671, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29892, 541, 372, 723, 5517, 367, 3755, 4280, 29889, 838, 5869, 293, 20159, 338, 263, 5443, 310, 23964, 393, 11898, 278, 4426, 310, 8162, 393, 526, 21634, 1090, 9126, 316, 689, 800, 29892, 1316, 408, 16116, 292, 322, 289, 2548, 29889, 512, 278, 3030, 310, 341, 1475, 705, 11356, 29892, 445, 1033, 367, 1304, 304, 27599, 278, 3829, 310, 278, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 2398, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 739, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 450, 6996, 2969, 723, 367, 304, 671, 21531, 20159, 304, 27599, 278, 3829, 310, 278, 341, 1475, 705, 11356, 3748, 7613, 322, 8161, 278, 14354, 310, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 910, 1033, 25135, 15783, 278, 3748, 7613, 408, 263, 25002, 2913, 322, 773, 25002, 15939, 1934, 1316, 408, 3632, 327, 2270, 6471, 322, 350, 9890, 3694, 304, 10115, 2472, 1048, 278, 4423, 310, 29086, 373, 278, 7613, 29889, 2398, 29892, 408, 306, 5276, 1434, 29892, 16049, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 577, 372, 29915, 29879, 451, 1554, 393, 508, 367, 10824, 297, 263, 2560, 1234, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1094, 263, 2702, 1342, 29892, 7755, 591, 505, 263, 341, 1475, 705, 11356, 3748, 7613, 411, 278, 1494, 5285, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 797, 445, 1206, 29892, 278, 3748, 7613, 508, 367, 9875, 408, 263, 25002, 2913, 29892, 988, 1269, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29889, 450, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 5293, 21531, 20159, 29892, 591, 508, 27599, 278, 3829, 310, 445, 2913, 322, 8161, 278, 14354, 310, 278, 29086, 2729, 373, 3058, 25002, 15939, 1934, 29889, 13, 13, 2831, 2777, 29892, 697, 982, 304, 437, 445, 338, 304, 10272, 278, 3632, 327, 2270, 6471, 310, 278, 2913, 29889, 15089, 327, 2270, 6471, 526, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 29892, 322, 896, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 512, 278, 1206, 310, 278, 7613, 2038, 29892, 591, 508, 10272, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 29892, 607, 16612, 278, 1353, 322, 376, 12181, 29908, 310, 278, 26532, 297, 278, 2913, 29889, 512, 445, 1206, 29892, 278, 937, 3632, 327, 2270, 2318, 723, 2649, 502, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29889, 13, 13, 29933, 1463, 373, 445, 2472, 29892, 591, 508, 10115, 393, 278, 29086, 297, 278, 3748, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29892, 408, 4318, 2400, 29901, 13, 28956, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 29871, 29896, 29871, 341, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 1577, 259, 29896, 13, 29871, 29896, 29871, 1577, 29871, 1577, 29871, 341, 259, 29896, 13, 29871, 29896, 259, 29896, 259, 29896, 259, 29896, 259, 29896, 13, 28956, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1342, 310, 920, 21531, 20159, 1033, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 526, 1784, 916, 1950, 13501, 29892, 322, 278, 2702, 4902, 310, 920, 304, 2334, 963, 723, 8839, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 1094, 306, 5276, 1434, 29892, 16049, 263, 1650, 763, 445, 723, 1996, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 18585, 29889, 17090, 366, 505, 263, 6483, 8004, 310, 1716, 21531, 20159, 322, 278, 3748, 310, 341, 1475, 705, 11356, 29892, 1244, 338, 920, 366, 1795, 671, 21531, 20159, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 4803, 25002, 15939, 1934, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1670, 526, 1784, 1950, 25002, 15939, 1934, 393, 1033, 367, 1304, 363, 445, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29941, 29889, 4803, 278, 2472, 10115, 1127, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 937, 3632, 327, 2270, 2318, 310, 278, 2913, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 21531, 20159, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 25002, 15939, 1934, 1641, 1304, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 21531, 20159, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 671, 28152, 3632, 3002, 470, 738, 310, 278, 916, 25002, 15939, 1934, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29892, 366, 723, 1101, 278, 1021, 6996, 6576, 408, 5439, 2038, 29892, 541, 411, 777, 12651, 297, 278, 2702, 4902, 29889, 13, 13, 2831, 2777, 29892, 1244, 338, 920, 366, 1795, 671, 28152, 3632, 3002, 304, 4505, 263, 341, 1475, 705, 11356, 3748, 7613, 29901, 13, 13, 29896, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 7806, 6862, 373, 278, 7613, 16161, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 508, 367, 2714, 310, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29892, 607, 526, 3291, 393, 526, 451, 760, 310, 278, 2913, 29889, 13, 29906, 29889, 11796, 29872, 278, 28152, 3632, 3002, 310, 278, 2913, 29889, 9034, 9696, 3632, 3002, 338, 263, 982, 310, 20766, 278, 26532, 297, 263, 2913, 975, 931, 29892, 322, 372, 508, 367, 1304, 304, 10115, 2472, 1048, 278, 14354, 310, 278, 29086, 373, 278, 341, 1475, 705, 11356, 7613, 29889, 13, 29941, 29889, 4803, 278, 2472, 515, 278, 28152, 3632, 3002, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 591, 508, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29946, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 508, 671, 445, 2472, 304, 4772, 963, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 773, 28152, 3632, 3002, 304, 4505, 341, 1475, 705, 11356, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 3748, 7613, 322, 278, 2702, 4902, 310, 278, 28152, 3632, 3002, 16287, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 28152, 3632, 3002, 508, 367, 1304, 304, 4505, 341, 1475, 705, 11356, 29889, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 319, 1799, 9047, 13566, 29901, 1763, 2334, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29892, 366, 723, 817, 304, 437, 278, 1494, 29901, 13, 13, 29896, 29889, 16052, 263, 5132, 3577, 363, 21531, 20159, 29889, 1670, 526, 3196, 3987, 3625, 29892, 1316, 408, 402, 566, 2918, 29892, 360, 291, 952, 375, 29892, 470, 4560, 7354, 29899, 29873, 1388, 29889, 4525, 9741, 3867, 14009, 322, 848, 12286, 363, 20602, 25002, 15939, 1934, 29892, 1316, 408, 3632, 327, 2270, 6471, 29892, 350, 9890, 3694, 29892, 470, 28152, 3632, 3002, 29889, 13, 29906, 29889, 16314, 278, 341, 1475, 705, 11356, 3748, 7613, 408, 263, 25002, 2913, 29889, 910, 1033, 367, 2309, 773, 263, 848, 3829, 4944, 491, 278, 21531, 20159, 3577, 29892, 1316, 408, 263, 3053, 506, 616, 4280, 470, 263, 478, 2035, 29367, 29899, 29934, 4512, 4280, 29889, 7806, 6862, 373, 278, 7613, 723, 3928, 304, 263, 1298, 297, 278, 2913, 29892, 322, 278, 29086, 373, 278, 7613, 723, 367, 9875, 408, 376, 5391, 267, 29908, 297, 278, 2913, 29889, 13, 29941, 29889, 11796, 29872, 278, 25002, 15939, 1934, 310, 278, 2913, 29889, 910, 1033, 367, 2309, 773, 14009, 4944, 491, 278, 21531, 20159, 3577, 29889, 1152, 1342, 29892, 366, 1033, 10272, 278, 28152, 3632, 3002, 310, 278, 2913, 773, 278, 28152, 3632, 3002, 5687, 29889, 13, 29946, 29889, 4803, 278, 2472, 515, 278, 25002, 15939, 1934, 304, 8161, 278, 14354, 310, 278, 29086, 373, 278, 7613, 29889, 910, 1033, 367, 2309, 773, 263, 10296, 310, 19475, 24481, 322, 8720, 29889, 1152, 1342, 29892, 565, 278, 28152, 3632, 3002, 14088, 393, 727, 526, 3023, 26532, 297, 278, 2913, 29892, 21050, 297, 263, 6862, 8267, 29892, 769, 366, 1033, 671, 263, 8424, 310, 775, 304, 10115, 393, 278, 29086, 526, 5982, 297, 278, 3023, 26995, 310, 278, 7613, 29889, 13, 29945, 29889, 4803, 445, 2472, 304, 4505, 278, 341, 1475, 705, 11356, 3748, 29889, 9038, 366, 1073, 278, 14354, 310, 278, 29086, 29892, 366, 1033, 671, 263, 8424, 310, 775, 304, 6336, 2828, 373, 278, 9109, 25256, 322, 2821, 278, 7613, 29889, 13, 13, 2776, 3236, 29892, 445, 338, 925, 697, 1950, 2948, 304, 16049, 263, 341, 1475, 705, 11356, 899, 369, 393, 3913, 21531, 20159, 297, 5132, 29889, 1670, 1122, 367, 916, 5837, 304, 437, 372, 29892, 8679, 373, 278, 2702, 4902, 310, 278, 5314, 29889, 2398, 29892, 445, 881, 2367, 366, 263, 2498, 2969, 310, 920, 304, 679, 4687, 411, 1316, 263, 2060, 29889, 2]}
diff --git a/tests/test_dict.py b/tests/test_dict.py
index 81a528fe4..4852707fb 100644
--- a/tests/test_dict.py
+++ b/tests/test_dict.py
@@ -1,3 +1,6 @@
+"""Module for testing DictDefault class"""
+
+
import unittest
import pytest
@@ -6,6 +9,10 @@ from axolotl.utils.dict import DictDefault
class DictDefaultTest(unittest.TestCase):
+ """
+ Test DictDefault class
+ """
+
def test_dict_default(self):
cfg = DictDefault(
{
@@ -41,7 +48,9 @@ class DictDefaultTest(unittest.TestCase):
}
)
- cfg = cfg | DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
+ cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
+ {"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
+ )
assert (
cfg.key_a.key_b == "value_b"
@@ -73,7 +82,7 @@ class DictDefaultTest(unittest.TestCase):
AttributeError,
match=r"'NoneType' object has no attribute 'another_random_key'",
):
- cfg.random_key.another_random_key
+ cfg.random_key.another_random_key = "value"
def test_dict_shorthand_assignment(self):
"""
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 7595ffbe4..fa85fe5f6 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -1,3 +1,4 @@
+"""Module for testing prompt tokenizers."""
import json
import logging
import unittest
@@ -12,6 +13,10 @@ logging.basicConfig(level="INFO")
class TestPromptTokenizationStrategies(unittest.TestCase):
+ """
+ Test class for prompt tokenization strategies.
+ """
+
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
@@ -24,10 +29,15 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
def test_sharegpt_integration(self):
print(Path(__file__).parent)
- with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
+ with open(
+ Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
+ ) as fin:
data = fin.read()
conversation = json.loads(data)
- with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
+ with open(
+ Path(__file__).parent / "fixtures/conversation.tokenized.json",
+ encoding="utf-8",
+ ) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
diff --git a/tests/test_prompters.py b/tests/test_prompters.py
index 1c3c13852..11610ccc5 100644
--- a/tests/test_prompters.py
+++ b/tests/test_prompters.py
@@ -1,9 +1,15 @@
+"""Module testing prompters"""
+
import unittest
from axolotl.prompters import AlpacaPrompter, PromptStyle
class AlpacaPrompterTest(unittest.TestCase):
+ """
+ Test AlpacaPrompter
+ """
+
def test_prompt_style_w_none(self):
prompter = AlpacaPrompter(prompt_style=None)
res = next(prompter.build_prompt("tell me a joke"))
@@ -11,8 +17,10 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Instruction:" in res
def test_prompt_style_w_instruct(self):
- prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
- res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)
+ res = next(
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
+ )
assert "Below is an instruction" in res
assert "### Instruction:" in res
assert "### Input:" in res
@@ -29,8 +37,10 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "ASSISTANT:" not in res
def test_prompt_style_w_chat(self):
- prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
- res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
+ prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
+ res = next(
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
+ )
assert "Below is an instruction" in res
assert "### Instruction:" not in res
assert "### Input:" not in res
@@ -45,5 +55,3 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res
-
-
diff --git a/tests/test_validation.py b/tests/test_validation.py
index af38eb6af..15bc07f84 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1,12 +1,18 @@
+"""Module for testing the validation module"""
+
import unittest
import pytest
-from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
+from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):
+ """
+ Test the validation module
+ """
+
def test_load_4bit_deprecate(self):
cfg = DictDefault(
{
@@ -24,7 +30,7 @@ class ValidationTest(unittest.TestCase):
}
)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
}
@@ -33,7 +39,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
}
@@ -42,7 +48,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": False,
}
@@ -51,7 +57,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
}
@@ -67,7 +73,7 @@ class ValidationTest(unittest.TestCase):
}
)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_8bit": True,
}
@@ -76,7 +82,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"gptq": True,
}
@@ -85,7 +91,7 @@ class ValidationTest(unittest.TestCase):
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault(
+ cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"load_in_4bit": True,
}
@@ -111,4 +117,3 @@ class ValidationTest(unittest.TestCase):
}
)
validate_config(cfg)
-