Apply isort then black

This commit is contained in:
NanoCode012
2023-05-29 18:48:58 +09:00
parent 96e8378692
commit 37293dce07
15 changed files with 158 additions and 97 deletions

View File

@@ -2,23 +2,20 @@
import os
import sys
from typing import Optional, Union
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
StdoutWriter,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")

View File

@@ -7,20 +7,20 @@ import random
import signal
import sys
from pathlib import Path
from typing import Optional, List, Dict, Any, Union
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
if cfg.local_rank == 0:
signal.signal(
signal.SIGINT,
lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
)
logging.info("Starting trainer...")
@@ -255,7 +258,8 @@ def train(
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(

View File

@@ -1,6 +1,6 @@
"""setup.py for axolotl"""
from setuptools import setup, find_packages
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:

View File

@@ -5,8 +5,8 @@ from typing import List
import torch
from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
logging.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
buffer_len = 0
if example:

View File

@@ -5,14 +5,11 @@
from typing import Optional, Tuple
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
@@ -75,7 +72,11 @@ def forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@ def forward(
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask

View File

@@ -1,6 +1,7 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,

View File

@@ -1,8 +1,9 @@
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
from typing import Tuple, Union, Generator
from typing import Generator, Tuple, Union
import yaml
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
@@ -61,10 +62,14 @@ Answer: {answer}
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -97,10 +102,14 @@ Evaluation:
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
)
evaluation = scores + critiques
question = prompt["instruction"]
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
def load_answer(tokenizer, cfg):
return CreativeAnsweringPromptTokenizingStrategy(
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeAnswerPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_critique(tokenizer, cfg):
return CreativeCritiquePromptTokenizingStrategy(
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeCritiquePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_revise(tokenizer, cfg):
return CreativeRevisePromptTokenizingStrategy(
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
CreativeRevisePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=True
part.strip(),
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
res = self._tokenize(
part.strip(), add_eos_token=True, strip_bos_token=True
part.strip(),
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [assistant_token, *res["input_ids"]]
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
else:

View File

@@ -2,8 +2,8 @@
import dataclasses
import logging
from enum import auto, Enum
from typing import List, Optional, Union, Generator
from enum import Enum, auto
from typing import Generator, List, Optional, Union
IGNORE_TOKEN_ID = -100
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
label = self.agent_label.format(
output=output, reflection=reflection, corrected=corrected
output=output,
reflection=reflection,
corrected=corrected,
)
res = f"{res}{label}"
yield res

View File

@@ -4,9 +4,9 @@ import os
from transformers import (
TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl,
TrainerState,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")

View File

@@ -5,38 +5,33 @@ from hashlib import md5
from pathlib import Path
from typing import List, Tuple, Union
from datasets import (
load_from_disk,
load_dataset,
Dataset,
DatasetDict,
)
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy,
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter,
ShareGPTPrompter,
JeopardyPrompter,
CompletionPrompter,
MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter,
MultipleChoiceConcisePrompter,
)
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
try:
if cfg.push_dataset_to_hub:
dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
)
dataset = dataset["train"]
except Exception: # pylint: disable=broad-except
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
load_dataset(
d.path,
streaming=True,
use_auth_token=use_auth_token,
)
ds_from_hub = True
except FileNotFoundError:
pass
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
# prefer local dataset, even if hub exists
if Path(d.path).exists():
ds = load_dataset(
"json", data_files=d.path, streaming=False, split=None
"json",
data_files=d.path,
streaming=False,
split=None,
)
elif ds_from_hub:
if d.data_files:
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
)
else:
ds = load_dataset(
d.path, streaming=False, use_auth_token=use_auth_token
d.path,
streaming=False,
use_auth_token=use_auth_token,
)
else:
fp = hf_hub_download(
repo_id=d.path, repo_type="dataset", filename=d.data_files
repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
)
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
if not ds:
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
@@ -353,7 +362,8 @@ def load_prepare_datasets(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
)
else:
dataset = load_tokenized_prepared_datasets(
@@ -365,7 +375,8 @@ def load_prepare_datasets(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
)
dataset = dataset.shard(
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
num_shards=cfg.dataset_shard_num,
index=cfg.dataset_shard_idx,
)
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)

View File

@@ -5,23 +5,17 @@ import logging
import math
import os
from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import torch
import transformers
from transformers import ( # noqa: F401
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
AutoConfig,
BitsAndBytesConfig,
)
from transformers import AutoModelForCausalLM # noqa: F401
from transformers import PreTrainedModel # noqa: F401
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
try:
from transformers import (
LlamaForCausalLM,
)
from transformers import LlamaForCausalLM
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
from transformers import PreTrainedTokenizer # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer(
base_model_config,
@@ -56,7 +51,10 @@ def load_tokenizer(
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
"LlamaTokenizerFast",
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import (
AdaptionPromptConfig,
get_peft_model,
PeftModel,
)
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import (
LoraConfig,
get_peft_model,
PeftModel,
)
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])

View File

@@ -2,6 +2,7 @@
import logging
from termcolor import colored

View File

@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.schedulers import InterpolatingLogScheduler
class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
self.lr_scheduler = None
def create_scheduler(
self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
]: # only save in rank 0
callbacks.append(SavePeftModelCallback)
data_collator_kwargs = {

View File

@@ -4,8 +4,8 @@ import unittest
import pytest
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase):