Lint and format

This commit is contained in:
NanoCode012
2023-05-29 03:45:42 +09:00
parent a98deb31a6
commit 392dfd9b07
9 changed files with 82 additions and 58 deletions

2
.gitignore vendored
View File

@@ -160,4 +160,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/ .idea/

View File

@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
pip3 install awscli && \ pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic pip3 install -U --no-cache-dir pydantic

View File

@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<" bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>" eos_token: "<|endoftext|>"

View File

@@ -61,4 +61,3 @@ special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
bos_token: ">>ABSTRACT<<" bos_token: ">>ABSTRACT<<"
eos_token: "<|endoftext|>" eos_token: "<|endoftext|>"

View File

@@ -1,23 +1,39 @@
"""Module to convert json file to jsonl"""
import os import os
import sys import sys
from typing import Optional
from pathlib import Path from pathlib import Path
import fire import fire
from typing import Optional
from axolotl.convert import (
FileReader,
StdoutWriter,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
)
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
from axolotl.convert import *
def main( def main(
input: Path, file: Path,
output: Optional[Path] = None, output: Optional[Path] = None,
to_stdout: Optional[bool] = False, to_stdout: Optional[bool] = False,
): ):
"""
Convert a json file to jsonl
"""
file_reader = FileReader() file_reader = FileReader()
if to_stdout or output is None: if to_stdout or output is None:
writer = StdoutWriter() writer = StdoutWriter()
@@ -28,7 +44,7 @@ def main(
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer) converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(input, output) converter.convert(file, output)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,3 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib import importlib
import logging import logging
import os import os
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.utils.data import load_prepare_datasets from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -37,7 +40,7 @@ def choose_device(cfg):
try: try:
if torch.backends.mps.is_available(): if torch.backends.mps.is_available():
return "mps" return "mps"
except: except Exception: # pylint: disable=broad-exception-caught
return "cpu" return "cpu"
cfg.device = get_device() cfg.device = get_device()
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this # gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme
generated = model.generate( generated = model.generate(
inputs=batch["input_ids"].to(cfg.device), inputs=batch["input_ids"].to(cfg.device),
do_sample=True, do_sample=True,
@@ -130,12 +133,12 @@ def train(
config = choose_config(config) config = choose_config(config)
# load the config from the yaml file # load the config from the yaml file
with open(config, "r") as f: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader)) cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
for k in kwargs: for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already # if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False: if k in cfg_keys or cfg.strict is False:
# handle booleans # handle booleans
@@ -167,13 +170,11 @@ def train(
# load the tokenizer first # load the tokenizer first
logging.info("loading tokenizer...") logging.info("loading tokenizer...")
tokenizer = load_tokenizer( tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
cfg.base_model_config,
cfg.tokenizer_type,
cfg
)
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
@@ -262,10 +263,13 @@ def train(
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# pylint: disable=fixme
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
# pylint: disable=fixme
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time

View File

@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
else: else:
example_len = 0 example_len = 0
if ( if not example_len or (
not example_len buffer_len + int(add_concat_token) + example_len > self.seq_length
or buffer_len + int(add_concat_token) + example_len
> self.seq_length
): ):
if buffer["input_ids"]: if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[ input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
: self.seq_length : self.seq_length
] ]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if ( if labels.size() == input_ids.size() and (
labels.size() == input_ids.size() attention_mask.size() == input_ids.size()
and attention_mask.size() == input_ids.size()
): ):
yield { yield {
"input_ids": input_ids, "input_ids": input_ids,

View File

@@ -1,14 +1,12 @@
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Union from typing import Tuple, Union
from datasets import ( from datasets import (
load_from_disk, load_from_disk,
load_dataset, load_dataset,
IterableDataset,
Dataset, Dataset,
concatenate_datasets,
DatasetDict, DatasetDict,
) )
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
md5( md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@" # noqa: W503
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) + "|".join( # noqa: W503
+ "|" sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+ tokenizer_name )
+ "|" # noqa: W503
+ tokenizer_name # noqa: W503
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
) )
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
) )
dataset = dataset["train"] dataset = dataset["train"]
except: except Exception: # pylint: disable=broad-except
pass pass
if dataset: if dataset:
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
fp = hf_hub_download( fp = hf_hub_download(
repo_id=d.path, repo_type="dataset", filename=d.data_files repo_id=d.path, repo_type="dataset", filename=d.data_files
) )
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None) ds: Dataset = load_dataset(
"json", data_files=fp, streaming=False, split=None
)
if not ds: if not ds:
raise Exception("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
if "train" in ds: if "train" in ds:
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
num_shards=d.shards, index=0
)
else: else:
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) ds: Dataset = ds.shuffle(seed=42).shard(
num_shards=d.shards, index=0
)
d_type = d.type d_type = d.type
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
) -> (Dataset, Dataset): ) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -259,12 +265,14 @@ def load_prepare_datasets(
md5( md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@" # noqa: W503
+ str(max_packed_sequence_len) + str(max_packed_sequence_len) # noqa: W503
+ seed + seed # noqa: W503
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) + "|".join( # noqa: W503
+ "|" sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+ tokenizer_name )
+ "|" # noqa: W503
+ tokenizer_name # noqa: W503
).encode("utf-8") ).encode("utf-8")
).hexdigest() ).hexdigest()
) )
@@ -285,7 +293,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
) )
dataset = dataset["train"] dataset = dataset["train"]
except: except Exception: # pylint: disable=broad-except
pass pass
if dataset: if dataset:
@@ -327,9 +335,9 @@ def load_prepare_datasets(
d d
for d in dataset for d in dataset
if len(d["input_ids"]) < cfg.sequence_len if len(d["input_ids"]) < cfg.sequence_len
and len(d["input_ids"]) > 0 and len(d["input_ids"]) > 0 # noqa: W503
and len(d["input_ids"]) == len(d["attention_mask"]) and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
and len(d["input_ids"]) == len(d["labels"]) and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
] ]
) )

View File

@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
def test_prompt_style_w_instruct(self): def test_prompt_style_w_instruct(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "Below is an instruction" in res assert "Below is an instruction" in res
assert "### Instruction:" in res assert "### Instruction:" in res
assert "### Input:" in res assert "### Input:" in res
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
def test_prompt_style_w_chat(self): def test_prompt_style_w_chat(self):
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) res = next(
prompter.build_prompt("tell me a joke about the following", "alpacas")
)
assert "Below is an instruction" in res assert "Below is an instruction" in res
assert "### Instruction:" not in res assert "### Instruction:" not in res
assert "### Input:" not in res assert "### Input:" not in res
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res assert "### Response:" not in res
assert "USER:" in res assert "USER:" in res
assert "ASSISTANT:" in res assert "ASSISTANT:" in res