diff --git a/.gitignore b/.gitignore index 93a4f81b5..614a6676b 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -.idea/ \ No newline at end of file +.idea/ diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index a61f6d42d..0ce43b621 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -99,4 +99,3 @@ RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \ pip3 install awscli && \ # The base image ships with `pydantic==1.8.2` which is not working pip3 install -U --no-cache-dir pydantic - diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 1291198cf..090cc6bcf 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -61,4 +61,3 @@ special_tokens: pad_token: "<|endoftext|>" bos_token: ">>ABSTRACT<<" eos_token: "<|endoftext|>" - diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 787c4121c..dc67d6125 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -61,4 +61,3 @@ special_tokens: pad_token: "<|endoftext|>" bos_token: ">>ABSTRACT<<" eos_token: "<|endoftext|>" - diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index 98c968309..f535d1afc 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -1,23 +1,39 @@ +"""Module to convert json file to jsonl""" + import os import sys + +from typing import Optional from pathlib import Path import fire -from typing import Optional + + +from axolotl.convert import ( + FileReader, + StdoutWriter, + FileWriter, + JsonlSerializer, + JsonParser, + JsonToJsonlConverter, +) + # add src to the pythonpath so we don't need to pip install this project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) -from axolotl.convert import * - def main( - input: Path, + file: Path, output: Optional[Path] = None, to_stdout: Optional[bool] = False, ): + """ + Convert a json file to jsonl + """ + file_reader = FileReader() if to_stdout or output is None: writer = StdoutWriter() @@ -28,7 +44,7 @@ def main( converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer) - converter.convert(input, output) + converter.convert(file, output) if __name__ == "__main__": diff --git a/scripts/finetune.py b/scripts/finetune.py index 58f1c0957..029e94648 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,3 +1,5 @@ +"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" + import importlib import logging import os @@ -16,15 +18,16 @@ from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.validation import validate_config from axolotl.utils.dict import DictDefault -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -src_dir = os.path.join(project_root, "src") -sys.path.insert(0, src_dir) - from axolotl.utils.data import load_prepare_datasets from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer from axolotl.utils.wandb import setup_wandb_env_vars +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +src_dir = os.path.join(project_root, "src") +sys.path.insert(0, src_dir) + + logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" @@ -37,7 +40,7 @@ def choose_device(cfg): try: if torch.backends.mps.is_available(): return "mps" - except: + except Exception: # pylint: disable=broad-exception-caught return "cpu" cfg.device = get_device() @@ -73,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): model.eval() with torch.no_grad(): - # gc = GenerationConfig() # TODO swap out and use this + # gc = GenerationConfig() # TODO swap out and use this # pylint: disable=fixme generated = model.generate( inputs=batch["input_ids"].to(cfg.device), do_sample=True, @@ -130,12 +133,12 @@ def train( config = choose_config(config) # load the config from the yaml file - with open(config, "r") as f: - cfg: DictDefault = DictDefault(yaml.load(f, Loader=yaml.Loader)) + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader)) # if there are any options passed in the cli, if it is something that seems valid from the yaml, # then overwrite the value cfg_keys = cfg.keys() - for k in kwargs: + for k, _ in kwargs.items(): # if not strict, allow writing to cfg even if it's not in the yml already if k in cfg_keys or cfg.strict is False: # handle booleans @@ -167,13 +170,11 @@ def train( # load the tokenizer first logging.info("loading tokenizer...") - tokenizer = load_tokenizer( - cfg.base_model_config, - cfg.tokenizer_type, - cfg - ) + tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg) - if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these + if check_not_in( + ["inference", "shard", "merge_lora"], kwargs + ): # don't need to load dataset for these train_dataset, eval_dataset = load_prepare_datasets( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) @@ -262,10 +263,13 @@ def train( logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + # pylint: disable=fixme # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.local_rank == 0: model.save_pretrained(cfg.output_dir) + + # pylint: disable=fixme # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 0e166f6f0..c7bb9fbfe 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -82,10 +82,8 @@ class ConstantLengthDataset(IterableDataset): else: example_len = 0 - if ( - not example_len - or buffer_len + int(add_concat_token) + example_len - > self.seq_length + if not example_len or ( + buffer_len + int(add_concat_token) + example_len > self.seq_length ): if buffer["input_ids"]: input_ids = torch.cat(buffer["input_ids"], dim=-1)[ @@ -95,9 +93,8 @@ class ConstantLengthDataset(IterableDataset): : self.seq_length ] labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if ( - labels.size() == input_ids.size() - and attention_mask.size() == input_ids.size() + if labels.size() == input_ids.size() and ( + attention_mask.size() == input_ids.size() ): yield { "input_ids": input_ids, diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index a0cff21c4..6d2123eea 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,14 +1,12 @@ import logging from hashlib import md5 from pathlib import Path -from typing import Union +from typing import Tuple, Union from datasets import ( load_from_disk, load_dataset, - IterableDataset, Dataset, - concatenate_datasets, DatasetDict, ) from huggingface_hub import hf_hub_download @@ -48,10 +46,12 @@ def load_tokenized_prepared_datasets( md5( ( str(cfg.sequence_len) - + "@" - + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) - + "|" - + tokenizer_name + + "@" # noqa: W503 + + "|".join( # noqa: W503 + sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) + ) + + "|" # noqa: W503 + + tokenizer_name # noqa: W503 ).encode("utf-8") ).hexdigest() ) @@ -68,7 +68,7 @@ def load_tokenized_prepared_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] - except: + except Exception: # pylint: disable=broad-except pass if dataset: @@ -109,15 +109,21 @@ def load_tokenized_prepared_datasets( fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None) + ds: Dataset = load_dataset( + "json", data_files=fp, streaming=False, split=None + ) if not ds: - raise Exception("unhandled dataset load") + raise ValueError("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) + ds: DatasetDict = ds.shuffle(seed=42)["train"].shard( + num_shards=d.shards, index=0 + ) else: - ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) + ds: Dataset = ds.shuffle(seed=42).shard( + num_shards=d.shards, index=0 + ) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] @@ -243,7 +249,7 @@ def load_tokenized_prepared_datasets( def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path -) -> (Dataset, Dataset): +) -> Tuple[Dataset, Dataset]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -259,12 +265,14 @@ def load_prepare_datasets( md5( ( str(cfg.sequence_len) - + "@" - + str(max_packed_sequence_len) - + seed - + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])) - + "|" - + tokenizer_name + + "@" # noqa: W503 + + str(max_packed_sequence_len) # noqa: W503 + + seed # noqa: W503 + + "|".join( # noqa: W503 + sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) + ) + + "|" # noqa: W503 + + tokenizer_name # noqa: W503 ).encode("utf-8") ).hexdigest() ) @@ -285,7 +293,7 @@ def load_prepare_datasets( f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token ) dataset = dataset["train"] - except: + except Exception: # pylint: disable=broad-except pass if dataset: @@ -327,9 +335,9 @@ def load_prepare_datasets( d for d in dataset if len(d["input_ids"]) < cfg.sequence_len - and len(d["input_ids"]) > 0 - and len(d["input_ids"]) == len(d["attention_mask"]) - and len(d["input_ids"]) == len(d["labels"]) + and len(d["input_ids"]) > 0 # noqa: W503 + and len(d["input_ids"]) == len(d["attention_mask"]) # noqa: W503 + and len(d["input_ids"]) == len(d["labels"]) # noqa: W503 ] ) diff --git a/tests/test_prompters.py b/tests/test_prompters.py index 1c3c13852..b4a34c6c0 100644 --- a/tests/test_prompters.py +++ b/tests/test_prompters.py @@ -12,7 +12,9 @@ class AlpacaPrompterTest(unittest.TestCase): def test_prompt_style_w_instruct(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.instruct.value) - res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) assert "Below is an instruction" in res assert "### Instruction:" in res assert "### Input:" in res @@ -30,7 +32,9 @@ class AlpacaPrompterTest(unittest.TestCase): def test_prompt_style_w_chat(self): prompter = AlpacaPrompter(prompt_style=PromptStyle.chat.value) - res = next(prompter.build_prompt("tell me a joke about the following", "alpacas")) + res = next( + prompter.build_prompt("tell me a joke about the following", "alpacas") + ) assert "Below is an instruction" in res assert "### Instruction:" not in res assert "### Input:" not in res @@ -45,5 +49,3 @@ class AlpacaPrompterTest(unittest.TestCase): assert "### Response:" not in res assert "USER:" in res assert "ASSISTANT:" in res - -