Lint and format

NanoCode012
2023-05-29 03:45:42 +09:00
parent a98deb31a6
commit 392dfd9b07
9 changed files with 82 additions and 58 deletions

.gitignore
View File

@@ -160,4 +160,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
\ No newline at end of file
+.idea/

View File

@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic
-

View File

@@ -61,4 +61,3 @@ special_tokens:
   pad_token: "<|endoftext|>"
   bos_token: ">>ABSTRACT<<"
   eos_token: "<|endoftext|>"
-

View File

@@ -61,4 +61,3 @@ special_tokens:
   pad_token: "<|endoftext|>"
   bos_token: ">>ABSTRACT<<"
   eos_token: "<|endoftext|>"
-
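Both config hunks above set the same falcon-style special tokens, reusing <|endoftext|> for padding and EOS. A sketch of the equivalent tokenizer call these fields ultimately drive (the checkpoint id here is an illustrative assumption, not taken from the configs):

from transformers import AutoTokenizer

# hypothetical falcon checkpoint; substitute the base_model from the config
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
tokenizer.add_special_tokens(
    {
        "pad_token": "<|endoftext|>",
        "bos_token": ">>ABSTRACT<<",
        "eos_token": "<|endoftext|>",
    }
)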

View File

@@ -1,23 +1,39 @@
"""Module to convert json file to jsonl"""
import os
import sys
from typing import Optional
from pathlib import Path
import fire
from typing import Optional
from axolotl.convert import (
FileReader,
StdoutWriter,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
from axolotl.convert import *
def main(
input: Path,
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
if to_stdout or output is None:
writer = StdoutWriter()
@@ -28,7 +44,7 @@ def main(
     converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
-    converter.convert(input, output)
+    converter.convert(file, output)

 if __name__ == "__main__":
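The script wires FileReader, JsonParser, JsonlSerializer, and a writer into JsonToJsonlConverter. For orientation, a self-contained sketch of the same json-to-jsonl transformation using only the standard library (the axolotl.convert classes are assumed to decompose this pipeline):

import json
from pathlib import Path

def json_to_jsonl(src: Path, dst: Path) -> None:
    # parse the whole JSON document; a top-level list becomes one record per line
    data = json.loads(src.read_text(encoding="utf-8"))
    records = data if isinstance(data, list) else [data]
    with dst.open("w", encoding="utf-8") as out:
        for record in records:
            out.write(json.dumps(record, ensure_ascii=False) + "\n")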

View File

@@ -1,3 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
-project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-src_dir = os.path.join(project_root, "src")
-sys.path.insert(0, src_dir)
 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
+
+project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+src_dir = os.path.join(project_root, "src")
+sys.path.insert(0, src_dir)
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -37,7 +40,7 @@ def choose_device(cfg):
         try:
             if torch.backends.mps.is_available():
                 return "mps"
-        except:
+        except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"

     cfg.device = get_device()
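The bare except is narrowed because it would also swallow KeyboardInterrupt and SystemExit. A minimal sketch of the probe-and-fall-back pattern in isolation (assuming only that torch is importable; the cfg plumbing above is omitted):

import torch

def get_device() -> str:
    try:
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except Exception:  # pylint: disable=broad-exception-caught
        # device probing can raise on unusual builds; CPU always works
        pass
    return "cpu"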
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     model.eval()
     with torch.no_grad():
-        # gc = GenerationConfig()  # TODO swap out and use this
+        # gc = GenerationConfig()  # TODO swap out and use this # pylint: disable=fixme
         generated = model.generate(
             inputs=batch["input_ids"].to(cfg.device),
             do_sample=True,
@@ -130,12 +133,12 @@ def train(
         config = choose_config(config)

     # load the config from the yaml file
-    with open(config, "r") as f:
-        cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
+    with open(config, encoding="utf-8") as file:
+        cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     cfg_keys = cfg.keys()
-    for k in kwargs:
+    for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
         if k in cfg_keys or cfg.strict is False:
             # handle booleans
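The same load-then-override flow in a condensed, standalone form; yaml.safe_load stands in for the yaml.load(..., Loader=yaml.Loader) call above, and axolotl's DictDefault wrapper is omitted:

import yaml

def load_cfg(path: str, **kwargs) -> dict:
    # explicit encoding silences pylint's unspecified-encoding warning
    with open(path, encoding="utf-8") as file:
        cfg = yaml.safe_load(file) or {}
    # CLI options overwrite keys that already exist in the YAML
    for key, value in kwargs.items():
        if key in cfg:
            cfg[key] = value
    return cfg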
@@ -167,13 +170,11 @@ def train(
     # load the tokenizer first
     logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(
-        cfg.base_model_config,
-        cfg.tokenizer_type,
-        cfg
-    )
+    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)

-    if check_not_in(["inference", "shard", "merge_lora"], kwargs):  # don't need to load dataset for these
+    if check_not_in(
+        ["inference", "shard", "merge_lora"], kwargs
+    ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -262,10 +263,13 @@ def train(
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# pylint: disable=fixme
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir)
# pylint: disable=fixme
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
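The rank-0 guard in isolation; save_on_rank_zero is a hypothetical helper, and model is assumed to expose Hugging Face's save_pretrained:

import os

def save_on_rank_zero(model, output_dir: str) -> None:
    # torchrun exports LOCAL_RANK; a missing value means single-process rank 0
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if local_rank == 0:
        # exactly one writer: concurrent processes writing the same file corrupt it
        model.save_pretrained(output_dir)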

View File

@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
                 else:
                     example_len = 0

-                if (
-                    not example_len
-                    or buffer_len + int(add_concat_token) + example_len
-                    > self.seq_length
+                if not example_len or (
+                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                 ):
                     if buffer["input_ids"]:
                         input_ids = torch.cat(buffer["input_ids"], dim=-1)[
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
                             : self.seq_length
                         ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        if (
-                            labels.size() == input_ids.size()
-                            and attention_mask.size() == input_ids.size()
+                        if labels.size() == input_ids.size() and (
+                            attention_mask.size() == input_ids.size()
                         ):
                             yield {
                                 "input_ids": input_ids,

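The reworked condition flushes the packing buffer whenever the next example is empty or would overflow seq_length. The same logic over plain lists, stripped of tensors and attention masks (pack is a hypothetical reduction of ConstantLengthDataset.__iter__):

def pack(examples, seq_length, concat_token_id):
    buffer, buffer_len = [], 0
    for ids in examples:
        example_len = len(ids)
        # flush when the example is empty or would not fit in the current window
        if not example_len or buffer_len + 1 + example_len > seq_length:
            if buffer:
                yield buffer[:seq_length]
            buffer, buffer_len = [], 0
        if example_len:
            buffer += ids + [concat_token_id]
            buffer_len += example_len + 1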
View File

@@ -1,14 +1,12 @@
 import logging
 from hashlib import md5
 from pathlib import Path
-from typing import Union
+from typing import Tuple, Union

 from datasets import (
     load_from_disk,
     load_dataset,
     IterableDataset,
     Dataset,
     concatenate_datasets,
     DatasetDict,
 )
 from huggingface_hub import hf_hub_download
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
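The hash fingerprints every input that affects tokenization, so a cached prepared dataset is reused only for an identical configuration. The construction in standalone form (plain dicts stand in for the cfg.datasets entries):

from hashlib import md5

def dataset_fingerprint(sequence_len, datasets, tokenizer_name):
    # sorting makes the hash independent of dataset order in the config
    parts = sorted(f"{d['path']}:{d['type']}:{d['shards']}" for d in datasets)
    key = str(sequence_len) + "@" + "|".join(parts) + "|" + tokenizer_name
    return md5(key.encode("utf-8")).hexdigest()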
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
)
dataset = dataset["train"]
except:
except Exception: # pylint: disable=broad-except
pass
if dataset:
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
             fp = hf_hub_download(
                 repo_id=d.path, repo_type="dataset", filename=d.data_files
             )
-            ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
+            ds: Dataset = load_dataset(
+                "json", data_files=fp, streaming=False, split=None
+            )
         if not ds:
-            raise Exception("unhandled dataset load")
+            raise ValueError("unhandled dataset load")
         # support for using a subset of the data
         if d.shards:
             if "train" in ds:
-                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
+                    num_shards=d.shards, index=0
+                )
             else:
-                ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
+                ds: Dataset = ds.shuffle(seed=42).shard(
+                    num_shards=d.shards, index=0
+                )
         d_type = d.type
         d_type_split = d_type.split(":")
         d_base_type = d_type_split[0]
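shuffle-then-shard keeps only a deterministic random fraction of the rows: shard(num_shards=n, index=0) retains roughly 1/n of the data. A short usage sketch against the datasets library:

from datasets import Dataset

ds = Dataset.from_dict({"text": [str(i) for i in range(100)]})
# shuffle first so the shard is a random subset rather than the leading rows;
# the fixed seed keeps the subset reproducible across runs
subset = ds.shuffle(seed=42).shard(num_shards=10, index=0)  # ~10 rows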
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
 def load_prepare_datasets(
     tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
-) -> (Dataset, Dataset):
+) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
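The annotation fix matters beyond style: (Dataset, Dataset) evaluates at runtime to a tuple of classes, which type checkers reject, while Tuple[Dataset, Dataset] is the actual generic type. A minimal illustration with a hypothetical pair function:

from typing import Tuple

def pair() -> Tuple[int, int]:  # "-> (int, int)" would not be a valid annotation
    return 1, 2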
@@ -259,12 +265,14 @@ def load_prepare_datasets(
         md5(
             (
                 str(cfg.sequence_len)
-                + "@"
-                + str(max_packed_sequence_len)
-                + seed
-                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
-                + "|"
-                + tokenizer_name
+                + "@"  # noqa: W503
+                + str(max_packed_sequence_len)  # noqa: W503
+                + seed  # noqa: W503
+                + "|".join(  # noqa: W503
+                    sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
+                )
+                + "|"  # noqa: W503
+                + tokenizer_name  # noqa: W503
             ).encode("utf-8")
         ).hexdigest()
     )
@@ -285,7 +293,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
)
dataset = dataset["train"]
except:
except Exception: # pylint: disable=broad-except
pass
if dataset:
@@ -327,9 +335,9 @@ def load_prepare_datasets(
             d
             for d in dataset
             if len(d["input_ids"]) < cfg.sequence_len
-            and len(d["input_ids"]) > 0
-            and len(d["input_ids"]) == len(d["attention_mask"])
-            and len(d["input_ids"]) == len(d["labels"])
+            and len(d["input_ids"]) > 0  # noqa: W503
+            and len(d["input_ids"]) == len(d["attention_mask"])  # noqa: W503
+            and len(d["input_ids"]) == len(d["labels"])  # noqa: W503
         ]
     )
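The comprehension keeps only samples that fit the context window and whose parallel fields stayed aligned. The same predicate factored out, assuming each sample is a dict with input_ids, attention_mask, and labels:

def is_valid(sample, sequence_len):
    ids = sample["input_ids"]
    # drop empty or over-long samples and any whose parallel fields drifted apart
    return (
        0 < len(ids) < sequence_len
        and len(ids) == len(sample["attention_mask"])
        and len(ids) == len(sample["labels"])
    )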

View File

@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
     def test_prompt_style_w_instruct(self):
         prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
-        res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
         assert "Below is an instruction" in res
         assert "### Instruction:" in res
         assert "### Input:" in res
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
     def test_prompt_style_w_chat(self):
         prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
-        res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
         assert "Below is an instruction" in res
         assert "### Instruction:" not in res
         assert "### Input:" not in res
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
assert "### Response:" not in res
assert "USER:" in res
assert "ASSISTANT:" in res