Lint and format
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -160,4 +160,4 @@ cython_debug/
|
|||||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
.idea/
|
.idea/
|
||||||
|
|||||||
@@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \
|
|||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic
|
pip3 install -U --no-cache-dir pydantic
|
||||||
|
|
||||||
|
|||||||
@@ -61,4 +61,3 @@ special_tokens:
|
|||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
bos_token: ">>ABSTRACT<<"
|
bos_token: ">>ABSTRACT<<"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
|||||||
@@ -61,4 +61,3 @@ special_tokens:
|
|||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
bos_token: ">>ABSTRACT<<"
|
bos_token: ">>ABSTRACT<<"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +1,39 @@
|
|||||||
|
"""Module to convert json file to jsonl"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
from axolotl.convert import (
|
||||||
|
FileReader,
|
||||||
|
StdoutWriter,
|
||||||
|
FileWriter,
|
||||||
|
JsonlSerializer,
|
||||||
|
JsonParser,
|
||||||
|
JsonToJsonlConverter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
from axolotl.convert import *
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
input: Path,
|
file: Path,
|
||||||
output: Optional[Path] = None,
|
output: Optional[Path] = None,
|
||||||
to_stdout: Optional[bool] = False,
|
to_stdout: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Convert a json file to jsonl
|
||||||
|
"""
|
||||||
|
|
||||||
file_reader = FileReader()
|
file_reader = FileReader()
|
||||||
if to_stdout or output is None:
|
if to_stdout or output is None:
|
||||||
writer = StdoutWriter()
|
writer = StdoutWriter()
|
||||||
@@ -28,7 +44,7 @@ def main(
|
|||||||
|
|
||||||
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
|
||||||
|
|
||||||
converter.convert(input, output)
|
converter.convert(file, output)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels
|
|||||||
from axolotl.utils.validation import validate_config
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
src_dir = os.path.join(project_root, "src")
|
|
||||||
sys.path.insert(0, src_dir)
|
|
||||||
|
|
||||||
from axolotl.utils.data import load_prepare_datasets
|
from axolotl.utils.data import load_prepare_datasets
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
@@ -37,7 +40,7 @@ def choose_device(cfg):
|
|||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
return "mps"
|
return "mps"
|
||||||
except:
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
@@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# gc = GenerationConfig() # TODO swap out and use this
|
# gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme
|
||||||
generated = model.generate(
|
generated = model.generate(
|
||||||
inputs=batch["input_ids"].to(cfg.device),
|
inputs=batch["input_ids"].to(cfg.device),
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
@@ -130,12 +133,12 @@ def train(
|
|||||||
config = choose_config(config)
|
config = choose_config(config)
|
||||||
|
|
||||||
# load the config from the yaml file
|
# load the config from the yaml file
|
||||||
with open(config, "r") as f:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader))
|
cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
# then overwrite the value
|
# then overwrite the value
|
||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
for k in kwargs:
|
for k, _ in kwargs.items():
|
||||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||||
if k in cfg_keys or cfg.strict is False:
|
if k in cfg_keys or cfg.strict is False:
|
||||||
# handle booleans
|
# handle booleans
|
||||||
@@ -167,13 +170,11 @@ def train(
|
|||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
logging.info("loading tokenizer...")
|
logging.info("loading tokenizer...")
|
||||||
tokenizer = load_tokenizer(
|
tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
|
||||||
cfg.base_model_config,
|
|
||||||
cfg.tokenizer_type,
|
|
||||||
cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
if check_not_in(
|
||||||
|
["inference", "shard", "merge_lora"], kwargs
|
||||||
|
): # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
@@ -262,10 +263,13 @@ def train(
|
|||||||
|
|
||||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
|
# pylint: disable=fixme
|
||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(cfg.output_dir)
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
# pylint: disable=fixme
|
||||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
else:
|
else:
|
||||||
example_len = 0
|
example_len = 0
|
||||||
|
|
||||||
if (
|
if not example_len or (
|
||||||
not example_len
|
buffer_len + int(add_concat_token) + example_len > self.seq_length
|
||||||
or buffer_len + int(add_concat_token) + example_len
|
|
||||||
> self.seq_length
|
|
||||||
):
|
):
|
||||||
if buffer["input_ids"]:
|
if buffer["input_ids"]:
|
||||||
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
|
||||||
@@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
: self.seq_length
|
: self.seq_length
|
||||||
]
|
]
|
||||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||||
if (
|
if labels.size() == input_ids.size() and (
|
||||||
labels.size() == input_ids.size()
|
attention_mask.size() == input_ids.size()
|
||||||
and attention_mask.size() == input_ids.size()
|
|
||||||
):
|
):
|
||||||
yield {
|
yield {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
load_from_disk,
|
load_from_disk,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
IterableDataset,
|
|
||||||
Dataset,
|
Dataset,
|
||||||
concatenate_datasets,
|
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
)
|
)
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
@@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets(
|
|||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@" # noqa: W503
|
||||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
+ "|".join( # noqa: W503
|
||||||
+ "|"
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
+ tokenizer_name
|
)
|
||||||
|
+ "|" # noqa: W503
|
||||||
|
+ tokenizer_name # noqa: W503
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except:
|
except Exception: # pylint: disable=broad-except
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if dataset:
|
if dataset:
|
||||||
@@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets(
|
|||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
||||||
)
|
)
|
||||||
ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None)
|
ds: Dataset = load_dataset(
|
||||||
|
"json", data_files=fp, streaming=False, split=None
|
||||||
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
raise Exception("unhandled dataset load")
|
raise ValueError("unhandled dataset load")
|
||||||
# support for using a subset of the data
|
# support for using a subset of the data
|
||||||
if d.shards:
|
if d.shards:
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(
|
||||||
|
num_shards=d.shards, index=0
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
|
ds: Dataset = ds.shuffle(seed=42).shard(
|
||||||
|
num_shards=d.shards, index=0
|
||||||
|
)
|
||||||
d_type = d.type
|
d_type = d.type
|
||||||
d_type_split = d_type.split(":")
|
d_type_split = d_type.split(":")
|
||||||
d_base_type = d_type_split[0]
|
d_base_type = d_type_split[0]
|
||||||
@@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
||||||
) -> (Dataset, Dataset):
|
) -> Tuple[Dataset, Dataset]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -259,12 +265,14 @@ def load_prepare_datasets(
|
|||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@" # noqa: W503
|
||||||
+ str(max_packed_sequence_len)
|
+ str(max_packed_sequence_len) # noqa: W503
|
||||||
+ seed
|
+ seed # noqa: W503
|
||||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
+ "|".join( # noqa: W503
|
||||||
+ "|"
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
+ tokenizer_name
|
)
|
||||||
|
+ "|" # noqa: W503
|
||||||
|
+ tokenizer_name # noqa: W503
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
)
|
)
|
||||||
@@ -285,7 +293,7 @@ def load_prepare_datasets(
|
|||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except:
|
except Exception: # pylint: disable=broad-except
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if dataset:
|
if dataset:
|
||||||
@@ -327,9 +335,9 @@ def load_prepare_datasets(
|
|||||||
d
|
d
|
||||||
for d in dataset
|
for d in dataset
|
||||||
if len(d["input_ids"]) < cfg.sequence_len
|
if len(d["input_ids"]) < cfg.sequence_len
|
||||||
and len(d["input_ids"]) > 0
|
and len(d["input_ids"]) > 0 # noqa: W503
|
||||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503
|
||||||
and len(d["input_ids"]) == len(d["labels"])
|
and len(d["input_ids"]) == len(d["labels"]) # noqa: W503
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_prompt_style_w_instruct(self):
|
def test_prompt_style_w_instruct(self):
|
||||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value)
|
||||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
res = next(
|
||||||
|
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||||
|
)
|
||||||
assert "Below is an instruction" in res
|
assert "Below is an instruction" in res
|
||||||
assert "### Instruction:" in res
|
assert "### Instruction:" in res
|
||||||
assert "### Input:" in res
|
assert "### Input:" in res
|
||||||
@@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_prompt_style_w_chat(self):
|
def test_prompt_style_w_chat(self):
|
||||||
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value)
|
||||||
res = next(prompter.build_prompt("tell me a joke about the following", "alpacas"))
|
res = next(
|
||||||
|
prompter.build_prompt("tell me a joke about the following", "alpacas")
|
||||||
|
)
|
||||||
assert "Below is an instruction" in res
|
assert "Below is an instruction" in res
|
||||||
assert "### Instruction:" not in res
|
assert "### Instruction:" not in res
|
||||||
assert "### Input:" not in res
|
assert "### Input:" not in res
|
||||||
@@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase):
|
|||||||
assert "### Response:" not in res
|
assert "### Response:" not in res
|
||||||
assert "USER:" in res
|
assert "USER:" in res
|
||||||
assert "ASSISTANT:" in res
|
assert "ASSISTANT:" in res
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user