From e9650d3ae471551acdca4c53f8e920efd3aa5167 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 29 May 2023 18:13:39 +0900 Subject: [PATCH] Fix mypy typing --- scripts/alpaca_json_to_jsonl.py | 3 +- scripts/extract_lora.py | 163 +++++++++++++++++++++ src/axolotl/prompt_strategies/pygmalion.py | 6 +- src/axolotl/prompt_tokenizers.py | 12 +- src/axolotl/prompters.py | 14 +- src/axolotl/utils/data.py | 20 +-- src/axolotl/utils/models.py | 2 +- src/axolotl/utils/trainer.py | 3 +- 8 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 scripts/extract_lora.py diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py index f535d1afc..2f56c07b3 100644 --- a/scripts/alpaca_json_to_jsonl.py +++ b/scripts/alpaca_json_to_jsonl.py @@ -3,7 +3,7 @@ import os import sys -from typing import Optional +from typing import Optional, Union from pathlib import Path import fire @@ -35,6 +35,7 @@ def main( """ file_reader = FileReader() + writer: Union[StdoutWriter, FileWriter] if to_stdout or output is None: writer = StdoutWriter() else: diff --git a/scripts/extract_lora.py b/scripts/extract_lora.py new file mode 100644 index 000000000..be88c5705 --- /dev/null +++ b/scripts/extract_lora.py @@ -0,0 +1,163 @@ +# import logging +# import os +# import random +# import signal +# import sys +# from pathlib import Path + +# import fire +# import torch +# import yaml +# from addict import Dict + +# from peft import set_peft_model_state_dict, get_peft_model_state_dict + +# # add src to the pythonpath so we don't need to pip install this +# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +# src_dir = os.path.join(project_root, "src") +# sys.path.insert(0, src_dir) + +# from axolotl.utils.data import load_prepare_datasets +# from axolotl.utils.models import load_model +# from axolotl.utils.trainer import setup_trainer +# from axolotl.utils.wandb import setup_wandb_env_vars + +# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) + + +# def choose_device(cfg): +# def get_device(): +# if torch.cuda.is_available(): +# return "cuda" +# else: +# try: +# if torch.backends.mps.is_available(): +# return "mps" +# except: +# return "cpu" + +# cfg.device = get_device() +# if cfg.device == "cuda": +# cfg.device_map = {"": cfg.local_rank} +# else: +# cfg.device_map = {"": cfg.device} + + +# def choose_config(path: Path): +# yaml_files = [file for file in path.glob("*.yml")] + +# if not yaml_files: +# raise ValueError( +# "No YAML config files found in the specified directory. Are you using a .yml extension?" +# ) + +# print("Choose a YAML file:") +# for idx, file in enumerate(yaml_files): +# print(f"{idx + 1}. {file}") + +# chosen_file = None +# while chosen_file is None: +# try: +# choice = int(input("Enter the number of your choice: ")) +# if 1 <= choice <= len(yaml_files): +# chosen_file = yaml_files[choice - 1] +# else: +# print("Invalid choice. Please choose a number from the list.") +# except ValueError: +# print("Invalid input. Please enter a number.") + +# return chosen_file + + +# def save_latest_checkpoint_as_lora( +# config: Path = Path("configs/"), +# prepare_ds_only: bool = False, +# **kwargs, +# ): +# if Path(config).is_dir(): +# config = choose_config(config) + +# # load the config from the yaml file +# with open(config, "r") as f: +# cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader)) +# # if there are any options passed in the cli, if it is something that seems valid from the yaml, +# # then overwrite the value +# cfg_keys = dict(cfg).keys() +# for k in kwargs: +# if k in cfg_keys: +# # handle booleans +# if isinstance(cfg[k], bool): +# cfg[k] = bool(kwargs[k]) +# else: +# cfg[k] = kwargs[k] + +# # setup some derived config / hyperparams +# cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size +# cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) +# cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) +# assert cfg.local_rank == 0, "Run this with only one device!" + +# choose_device(cfg) +# cfg.ddp = False + +# if cfg.device == "mps": +# cfg.load_in_8bit = False +# cfg.tf32 = False +# if cfg.bf16: +# cfg.fp16 = True +# cfg.bf16 = False + +# # Load the model and tokenizer +# logging.info("loading model, tokenizer, and lora_config...") +# model, tokenizer, lora_config = load_model( +# cfg.base_model, +# cfg.base_model_config, +# cfg.model_type, +# cfg.tokenizer_type, +# cfg, +# adapter=cfg.adapter, +# inference=True, +# ) + +# model.config.use_cache = False + +# if torch.__version__ >= "2" and sys.platform != "win32": +# logging.info("Compiling torch model") +# model = torch.compile(model) + +# possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")] +# if len(possible_checkpoints) > 0: +# sorted_paths = sorted( +# possible_checkpoints, key=lambda path: int(path.split("-")[-1]) +# ) +# resume_from_checkpoint = sorted_paths[-1] +# else: +# raise FileNotFoundError("Checkpoints folder not found") + +# pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin") + +# assert os.path.exists(pytorch_bin_path), "Bin not found" + +# logging.info(f"Loading {pytorch_bin_path}") +# adapters_weights = torch.load(pytorch_bin_path, map_location="cpu") + +# # d = get_peft_model_state_dict(model) +# print(model.load_state_dict(adapters_weights)) +# # with open('b.log', "w") as f: +# # f.write(str(d.keys())) +# assert False + +# print((adapters_weights.keys())) +# with open("a.log", "w") as f: +# f.write(str(adapters_weights.keys())) +# assert False + +# logging.info("Setting peft model state dict") +# set_peft_model_state_dict(model, adapters_weights) + +# logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}") +# model.save_pretrained(cfg.output_dir) + + +# if __name__ == "__main__": +# fire.Fire(save_latest_checkpoint_as_lora) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 4cd9a1685..d38bc2beb 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -3,7 +3,7 @@ import copy import logging from collections import defaultdict -from typing import Generator +from typing import Generator, List, Tuple from axolotl.prompt_tokenizers import ( PromptTokenizingStrategy, @@ -19,7 +19,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): Tokenizing strategy for Pygmalion. """ - bot_prefix_token_ids = [] + bot_prefix_token_ids: List[int] = [] def __init__(self, prompter, tokenizer, *args, **kwargs): super().__init__(prompter, tokenizer, *args, **kwargs) @@ -88,7 +88,7 @@ class PygmalionPrompter: def build_prompt( self, source, *args, **kwargs # pylint: disable=unused-argument - ) -> Generator[str, None, None]: + ) -> Generator[Tuple[str, str], None, None]: for msg in source: yield msg["role"], msg["value"] diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 761441a7e..d1655da32 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -226,20 +226,16 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): Tokenizing strategy for Completion prompts. """ - def parse_instruction_fields(self, prompt) -> str: - return prompt["text"] - def tokenize_prompt(self, prompt): - instruction = self.parse_instruction_fields(prompt) - full_prompt = self._build_full_prompt(instruction, None, None) + full_prompt = self._build_full_prompt(prompt["text"], None, None) tokenized_full_prompt = self._tokenize(full_prompt) return tokenized_full_prompt def _build_full_prompt( self, instruction, input, response - ): # pylint: disable=unused-argument, redefined-builtin - return next(iter(self.prompter.build_prompt(instruction))) + ): # pylint: disable=redefined-builtin + return next(iter(self.prompter.build_prompt(instruction, input, response))) class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): @@ -419,7 +415,7 @@ def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: Returns the default values for the tokenize prompt function """ - result = { + result: Dict[str, List[int]] = { "input_ids": [], "attention_mask": [], "labels": [], diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index eced1d4a5..97c2e3454 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -3,7 +3,7 @@ import dataclasses import logging from enum import auto, Enum -from typing import List, Union, Generator +from typing import List, Optional, Union, Generator IGNORE_TOKEN_ID = -100 @@ -24,7 +24,7 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" - prompt_style = None + prompt_style: Optional[PromptStyle] = None def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value @@ -231,18 +231,18 @@ class Conversation: offset: int sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep: str = "###" - sep2: str = None + sep2: Optional[str] = None def get_prompt(self) -> Generator[str, None, None]: - seps = [self.sep, self.sep2] - preamble = self.system + seps[0] + # seps = [self.sep, self.sep2] + preamble = self.system + self.sep yield preamble for _, (role, message) in enumerate(self.messages): if message: - yield (role + ":", " " + message) + yield role + ":" + " " + message else: logging.warning(f"role with empty message: {role}") - yield (role + ":",) + yield role + ":" def copy(self): return Conversation( diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 7b718bf56..74812f9a0 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -3,7 +3,7 @@ import logging from hashlib import md5 from pathlib import Path -from typing import Tuple, Union +from typing import List, Tuple, Union from datasets import ( load_from_disk, @@ -95,40 +95,36 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): - ds: Dataset = load_dataset( + ds = load_dataset( "json", data_files=d.path, streaming=False, split=None ) elif ds_from_hub: if d.data_files: - ds: Dataset = load_dataset( + ds = load_dataset( d.path, streaming=False, data_files=d.data_files, use_auth_token=use_auth_token, ) else: - ds: Dataset = load_dataset( + ds = load_dataset( d.path, streaming=False, use_auth_token=use_auth_token ) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds: Dataset = load_dataset( - "json", data_files=fp, streaming=False, split=None - ) + ds = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: raise ValueError("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds: DatasetDict = ds.shuffle(seed=42)["train"].shard( + ds = ds.shuffle(seed=42)["train"].shard( num_shards=d.shards, index=0 ) else: - ds: Dataset = ds.shuffle(seed=42).shard( - num_shards=d.shards, index=0 - ) + ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0] @@ -232,7 +228,7 @@ def load_tokenized_prepared_datasets( logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.info("tokenizing, merging, and shuffling master dataset") - samples = [] + samples: List[int] = [] for d in datasets: samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=42) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 7a81b8a49..5cdfaab3c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -81,7 +81,7 @@ def load_model( adapter="lora", inference=False, ): - # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] + # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] """ Load a model from a base model and a model type. """ diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 299e39664..45f13e530 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -5,6 +5,7 @@ import math import os import sys from pathlib import Path +from typing import Optional import bitsandbytes as bnb import torch.cuda @@ -28,7 +29,7 @@ class OneCycleLRSchedulerTrainer(Trainer): self.lr_scheduler = None def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None ): optimizer = self.optimizer if optimizer is None else optimizer num_warmup_steps = self.args.get_warmup_steps(num_training_steps)