Lint and format
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -160,4 +160,4 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
.idea/
|
||||
|
||||
@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic
|
||||
|
||||
|
||||
@@ -61,4 +61,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
@@ -61,4 +61,3 @@ special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
bos_token: ">>ABSTRACT<<"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
|
||||
@@ -1,23 +1,39 @@
|
||||
"""Module to convert json file to jsonl"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
from typing import Optional
|
||||
|
||||
|
||||
from axolotl.convert import (
|
||||
FileReader,
|
||||
StdoutWriter,
|
||||
FileWriter,
|
||||
JsonlSerializer,
|
||||
JsonParser,
|
||||
JsonToJsonlConverter,
|
||||
)
|
||||
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.convert import *
|
||||
|
||||
|
||||
def main(
|
||||
input: Path,
|
||||
file: Path,
|
||||
output: Optional[Path] = None,
|
||||
to_stdout: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
Convert a json file to jsonl
|
||||
"""
|
||||
|
||||
file_reader = FileReader()
|
||||
if to_stdout or output is None:
|
||||
writer = StdoutWriter()
|
||||
@@ -28,7 +44,7 @@ def main(
|
||||
|
||||
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||
|
||||
converter.convert(input, output)
|
||||
converter.convert(file, output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.validation import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from axolotl.utils.data import load_prepare_datasets
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
|
||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
@@ -37,7 +40,7 @@ def choose_device(cfg):
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# gc = GenerationConfig() # TODO swap out and use this
|
||||
# gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme
|
||||
generated = model.generate(
|
||||
inputs=batch["input_ids"].to(cfg.device),
|
||||
do_sample=True,
|
||||
@@ -130,12 +133,12 @@ def train(
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, "r") as f:
|
||||
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k in kwargs:
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or cfg.strict is False:
|
||||
# handle booleans
|
||||
@@ -167,13 +170,11 @@ def train(
|
||||
|
||||
# load the tokenizer first
|
||||
logging.info("loading tokenizer...")
|
||||
tokenizer = load_tokenizer(
|
||||
cfg.base_model_config,
|
||||
cfg.tokenizer_type,
|
||||
cfg
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
||||
|
||||
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
||||
if check_not_in(
|
||||
["inference", "shard", "merge_lora"], kwargs
|
||||
): # don't need to load dataset for these
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
@@ -262,10 +263,13 @@ def train(
|
||||
|
||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# pylint: disable=fixme
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
# pylint: disable=fixme
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
|
||||
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
else:
|
||||
example_len = 0
|
||||
|
||||
if (
|
||||
not example_len
|
||||
or buffer_len + int(add_concat_token) + example_len
|
||||
> self.seq_length
|
||||
if not example_len or (
|
||||
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||
):
|
||||
if buffer["input_ids"]:
|
||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
|
||||
: self.seq_length
|
||||
]
|
||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||
if (
|
||||
labels.size() == input_ids.size()
|
||||
and attention_mask.size() == input_ids.size()
|
||||
if labels.size() == input_ids.size() and (
|
||||
attention_mask.size() == input_ids.size()
|
||||
):
|
||||
yield {
|
||||
"input_ids": input_ids,
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
load_from_disk,
|
||||
load_dataset,
|
||||
IterableDataset,
|
||||
Dataset,
|
||||
concatenate_datasets,
|
||||
DatasetDict,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
|
||||
md5(
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
+ "@" # noqa: W503
|
||||
+ "|".join( # noqa: W503
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|" # noqa: W503
|
||||
+ tokenizer_name # noqa: W503
|
||||
).encode("utf-8")
|
||||
).hexdigest()
|
||||
)
|
||||
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
|
||||
fp = hf_hub_download(
|
||||
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
||||
)
|
||||
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
|
||||
ds: Dataset = load_dataset(
|
||||
"json", data_files=fp, streaming=False, split=None
|
||||
)
|
||||
if not ds:
|
||||
raise Exception("unhandled dataset load")
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
||||
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
||||
ds: Dataset = ds.shuffle(seed=42).shard(
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
||||
) -> (Dataset, Dataset):
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -259,12 +265,14 @@ def load_prepare_datasets(
|
||||
md5(
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ str(max_packed_sequence_len)
|
||||
+ seed
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
+ "@" # noqa: W503
|
||||
+ str(max_packed_sequence_len) # noqa: W503
|
||||
+ seed # noqa: W503
|
||||
+ "|".join( # noqa: W503
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|" # noqa: W503
|
||||
+ tokenizer_name # noqa: W503
|
||||
).encode("utf-8")
|
||||
).hexdigest()
|
||||
)
|
||||
@@ -285,7 +293,7 @@ def load_prepare_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except:
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
|
||||
if dataset:
|
||||
@@ -327,9 +335,9 @@ def load_prepare_datasets(
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
and len(d["input_ids"]) > 0 # noqa: W503
|
||||
and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
|
||||
and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
|
||||
def test_prompt_style_w_instruct(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "Below is an instruction" in res
|
||||
assert "### Instruction:" in res
|
||||
assert "### Input:" in res
|
||||
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
|
||||
def test_prompt_style_w_chat(self):
|
||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
||||
res = next(
|
||||
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||
)
|
||||
assert "Below is an instruction" in res
|
||||
assert "### Instruction:" not in res
|
||||
assert "### Input:" not in res
|
||||
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
|
||||
assert "### Response:" not in res
|
||||
assert "USER:" in res
|
||||
assert "ASSISTANT:" in res
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user