Fix mypy typing

This commit is contained in:
NanoCode012
2023-05-29 18:13:39 +09:00
parent f1232b35ba
commit e9650d3ae4
8 changed files with 190 additions and 33 deletions

View File

@@ -3,7 +3,7 @@
import os import os
import sys import sys
from typing import Optional from typing import Optional, Union
from pathlib import Path from pathlib import Path
import fire import fire
@@ -35,6 +35,7 @@ def main(
""" """
file_reader = FileReader() file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None: if to_stdout or output is None:
writer = StdoutWriter() writer = StdoutWriter()
else: else:

163
scripts/extract_lora.py Normal file
View File

@@ -0,0 +1,163 @@
# import logging
# import os
# import random
# import signal
# import sys
# from pathlib import Path
# import fire
# import torch
# import yaml
# from addict import Dict
# from peft import set_peft_model_state_dict, get_peft_model_state_dict
# # add src to the pythonpath so we don't need to pip install this
# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# src_dir = os.path.join(project_root, "src")
# sys.path.insert(0, src_dir)
# from axolotl.utils.data import load_prepare_datasets
# from axolotl.utils.models import load_model
# from axolotl.utils.trainer import setup_trainer
# from axolotl.utils.wandb import setup_wandb_env_vars
# logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
# def choose_device(cfg):
# def get_device():
# if torch.cuda.is_available():
# return "cuda"
# else:
# try:
# if torch.backends.mps.is_available():
# return "mps"
# except:
# return "cpu"
# cfg.device = get_device()
# if cfg.device == "cuda":
# cfg.device_map = {"": cfg.local_rank}
# else:
# cfg.device_map = {"": cfg.device}
# def choose_config(path: Path):
# yaml_files = [file for file in path.glob("*.yml")]
# if not yaml_files:
# raise ValueError(
# "No YAML config files found in the specified directory. Are you using a .yml extension?"
# )
# print("Choose a YAML file:")
# for idx, file in enumerate(yaml_files):
# print(f"{idx + 1}. {file}")
# chosen_file = None
# while chosen_file is None:
# try:
# choice = int(input("Enter the number of your choice: "))
# if 1 <= choice <= len(yaml_files):
# chosen_file = yaml_files[choice - 1]
# else:
# print("Invalid choice. Please choose a number from the list.")
# except ValueError:
# print("Invalid input. Please enter a number.")
# return chosen_file
# def save_latest_checkpoint_as_lora(
# config: Path = Path("configs/"),
# prepare_ds_only: bool = False,
# **kwargs,
# ):
# if Path(config).is_dir():
# config = choose_config(config)
# # load the config from the yaml file
# with open(config, "r") as f:
# cfg: Dict = Dict(lambda: None, yaml.load(f, Loader=yaml.Loader))
# # if any options are passed on the CLI and the key already exists in the YAML config,
# # overwrite the YAML value with the CLI value
# cfg_keys = dict(cfg).keys()
# for k in kwargs:
# if k in cfg_keys:
# # handle booleans
# if isinstance(cfg[k], bool):
# cfg[k] = bool(kwargs[k])
# else:
# cfg[k] = kwargs[k]
# # set up some derived config / hyperparams
# cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
# cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
# cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
# assert cfg.local_rank == 0, "Run this with only one device!"
# choose_device(cfg)
# cfg.ddp = False
# if cfg.device == "mps":
# cfg.load_in_8bit = False
# cfg.tf32 = False
# if cfg.bf16:
# cfg.fp16 = True
# cfg.bf16 = False
# # Load the model and tokenizer
# logging.info("loading model, tokenizer, and lora_config...")
# model, tokenizer, lora_config = load_model(
# cfg.base_model,
# cfg.base_model_config,
# cfg.model_type,
# cfg.tokenizer_type,
# cfg,
# adapter=cfg.adapter,
# inference=True,
# )
# model.config.use_cache = False
# if torch.__version__ >= "2" and sys.platform != "win32":
# logging.info("Compiling torch model")
# model = torch.compile(model)
# possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
# if len(possible_checkpoints) > 0:
# sorted_paths = sorted(
# possible_checkpoints, key=lambda path: int(path.split("-")[-1])
# )
# resume_from_checkpoint = sorted_paths[-1]
# else:
# raise FileNotFoundError("Checkpoints folder not found")
# pytorch_bin_path = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
# assert os.path.exists(pytorch_bin_path), "Bin not found"
# logging.info(f"Loading {pytorch_bin_path}")
# adapters_weights = torch.load(pytorch_bin_path, map_location="cpu")
# # d = get_peft_model_state_dict(model)
# print(model.load_state_dict(adapters_weights))
# # with open('b.log', "w") as f:
# # f.write(str(d.keys()))
# assert False
# print((adapters_weights.keys()))
# with open("a.log", "w") as f:
# f.write(str(adapters_weights.keys()))
# assert False
# logging.info("Setting peft model state dict")
# set_peft_model_state_dict(model, adapters_weights)
# logging.info(f"Set Completed!!! Saving pre-trained model to {cfg.output_dir}")
# model.save_pretrained(cfg.output_dir)
# if __name__ == "__main__":
# fire.Fire(save_latest_checkpoint_as_lora)

View File

@@ -3,7 +3,7 @@
import copy import copy
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Generator from typing import Generator, List, Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
PromptTokenizingStrategy, PromptTokenizingStrategy,
@@ -19,7 +19,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for Pygmalion. Tokenizing strategy for Pygmalion.
""" """
bot_prefix_token_ids = [] bot_prefix_token_ids: List[int] = []
def __init__(self, prompter, tokenizer, *args, **kwargs): def __init__(self, prompter, tokenizer, *args, **kwargs):
super().__init__(prompter, tokenizer, *args, **kwargs) super().__init__(prompter, tokenizer, *args, **kwargs)
@@ -88,7 +88,7 @@ class PygmalionPrompter:
def build_prompt( def build_prompt(
self, source, *args, **kwargs # pylint: disable=unused-argument self, source, *args, **kwargs # pylint: disable=unused-argument
) -> Generator[str, None, None]: ) -> Generator[Tuple[str, str], None, None]:
for msg in source: for msg in source:
yield msg["role"], msg["value"] yield msg["role"], msg["value"]

View File

@@ -226,20 +226,16 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
Tokenizing strategy for Completion prompts. Tokenizing strategy for Completion prompts.
""" """
def parse_instruction_fields(self, prompt) -> str:
return prompt["text"]
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
instruction = self.parse_instruction_fields(prompt) full_prompt = self._build_full_prompt(prompt["text"], None, None)
full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt) tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt return tokenized_full_prompt
def _build_full_prompt( def _build_full_prompt(
self, instruction, input, response self, instruction, input, response
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=redefined-builtin
return next(iter(self.prompter.build_prompt(instruction))) return next(iter(self.prompter.build_prompt(instruction, input, response)))
class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -419,7 +415,7 @@ def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
Returns the default values for the tokenize prompt function Returns the default values for the tokenize prompt function
""" """
result = { result: Dict[str, List[int]] = {
"input_ids": [], "input_ids": [],
"attention_mask": [], "attention_mask": [],
"labels": [], "labels": [],

View File

@@ -3,7 +3,7 @@
import dataclasses import dataclasses
import logging import logging
from enum import auto, Enum from enum import auto, Enum
from typing import List, Union, Generator from typing import List, Optional, Union, Generator
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -24,7 +24,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
prompt_style = None prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value): def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
@@ -231,18 +231,18 @@ class Conversation:
offset: int offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###" sep: str = "###"
sep2: str = None sep2: Optional[str] = None
def get_prompt(self) -> Generator[str, None, None]: def get_prompt(self) -> Generator[str, None, None]:
seps = [self.sep, self.sep2] # seps = [self.sep, self.sep2]
preamble = self.system + seps[0] preamble = self.system + self.sep
yield preamble yield preamble
for _, (role, message) in enumerate(self.messages): for _, (role, message) in enumerate(self.messages):
if message: if message:
yield (role + ":", " " + message) yield role + ":" + " " + message
else: else:
logging.warning(f"role with empty message: {role}") logging.warning(f"role with empty message: {role}")
yield (role + ":",) yield role + ":"
def copy(self): def copy(self):
return Conversation( return Conversation(

View File

@@ -3,7 +3,7 @@
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import List, Tuple, Union
from datasets import ( from datasets import (
load_from_disk, load_from_disk,
@@ -95,40 +95,36 @@ def load_tokenized_prepared_datasets(
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
if Path(d.path).exists(): if Path(d.path).exists():
ds: Dataset = load_dataset( ds = load_dataset(
"json", data_files=d.path, streaming=False, split=None "json", data_files=d.path, streaming=False, split=None
) )
elif ds_from_hub: elif ds_from_hub:
if d.data_files: if d.data_files:
ds: Dataset = load_dataset( ds = load_dataset(
d.path, d.path,
streaming=False, streaming=False,
data_files=d.data_files, data_files=d.data_files,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
else: else:
ds: Dataset = load_dataset( ds = load_dataset(
d.path, streaming=False, use_auth_token=use_auth_token d.path, streaming=False, use_auth_token=use_auth_token
) )
else: else:
fp = hf_hub_download( fp = hf_hub_download(
repo_id=d.path, repo_type="dataset", filename=d.data_files repo_id=d.path, repo_type="dataset", filename=d.data_files
) )
ds: Dataset = load_dataset( ds = load_dataset("json", data_files=fp, streaming=False, split=None)
"json", data_files=fp, streaming=False, split=None
)
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if d.shards:
if "train" in ds: if "train" in ds:
ds: DatasetDict = ds.shuffle(seed=42)["train"].shard( ds = ds.shuffle(seed=42)["train"].shard(
num_shards=d.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
ds: Dataset = ds.shuffle(seed=42).shard( ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
num_shards=d.shards, index=0
)
d_type = d.type d_type = d.type
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -232,7 +228,7 @@ def load_tokenized_prepared_datasets(
logging.error(f"unhandled prompt tokenization strategy: {d.type}") logging.error(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset") logging.info("tokenizing, merging, and shuffling master dataset")
samples = [] samples: List[int] = []
for d in datasets: for d in datasets:
samples = samples + list(d) samples = samples + list(d)
dataset = Dataset.from_list(samples).shuffle(seed=42) dataset = Dataset.from_list(samples).shuffle(seed=42)

View File

@@ -81,7 +81,7 @@ def load_model(
adapter="lora", adapter="lora",
inference=False, inference=False,
): ):
# type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] # type: (str, str, str, str, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
""" """
Load a model from a base model and a model type. Load a model from a base model and a model type.
""" """

View File

@@ -5,6 +5,7 @@ import math
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional
import bitsandbytes as bnb import bitsandbytes as bnb
import torch.cuda import torch.cuda
@@ -28,7 +29,7 @@ class OneCycleLRSchedulerTrainer(Trainer):
self.lr_scheduler = None self.lr_scheduler = None
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
): ):
optimizer = self.optimizer if optimizer is None else optimizer optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps) num_warmup_steps = self.args.get_warmup_steps(num_training_steps)