From 37293dce07a36f31d3d7f7c2a39f238c8c2a29a0 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 29 May 2023 18:48:58 +0900
Subject: [PATCH] Apply isort then black

---
 scripts/alpaca_json_to_jsonl.py               |  7 +--
 scripts/finetune.py                           | 20 +++---
 setup.py                                      |  2 +-
 src/axolotl/datasets.py                       |  8 ++-
 src/axolotl/flash_attn.py                     | 42 +++++++++----
 src/axolotl/prompt_strategies/alpaca_chat.py  |  1 +
 src/axolotl/prompt_strategies/creative_acr.py | 34 +++++++---
 src/axolotl/prompt_tokenizers.py              | 13 +++-
 src/axolotl/prompters.py                      |  8 ++-
 src/axolotl/utils/callbacks.py                |  7 ++-
 src/axolotl/utils/data.py                     | 63 +++++++++++--------
 src/axolotl/utils/models.py                   | 36 ++++-------
 src/axolotl/utils/tokenization.py             |  1 +
 src/axolotl/utils/trainer.py                  | 11 +++-
 tests/test_validation.py                      |  2 +-
 15 files changed, 158 insertions(+), 97 deletions(-)

diff --git a/scripts/alpaca_json_to_jsonl.py b/scripts/alpaca_json_to_jsonl.py
index 2f56c07b3..61cb170ec 100644
--- a/scripts/alpaca_json_to_jsonl.py
+++ b/scripts/alpaca_json_to_jsonl.py
@@ -2,23 +2,20 @@
 import os
 import sys
-
-from typing import Optional, Union
 from pathlib import Path
+from typing import Optional, Union
 
 import fire
-
 from axolotl.convert import (
     FileReader,
-    StdoutWriter,
     FileWriter,
     JsonlSerializer,
     JsonParser,
     JsonToJsonlConverter,
+    StdoutWriter,
 )
 
-
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 226068020..4716744b2 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -7,20 +7,20 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional, List, Dict, Any, Union
+from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
 
+from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.validation import validate_config
-from axolotl.utils.dict import DictDefault
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
     if cfg.local_rank == 0:
         signal.signal(
             signal.SIGINT,
-            lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
+            lambda signal, frame: (
+                model.save_pretrained(cfg.output_dir),
+                sys.exit(0),
+            ),
         )
 
     logging.info("Starting trainer...")
@@ -255,7 +258,8 @@
         ]
         if len(possible_checkpoints) > 0:
             sorted_paths = sorted(
-                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
             )
             resume_from_checkpoint = sorted_paths[-1]
             logging.info(
diff --git a/setup.py b/setup.py
index 7f51f495f..de9fdc62f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 """setup.py for axolotl"""
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = []
 with open("./requirements.txt", encoding="utf-8") as requirements_file:
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 1e72be114..fb5e15656 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -5,8 +5,8 @@ from typing import List
 
 import torch
 from datasets import IterableDataset
 
-from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
+from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
                     logging.warning(
                         f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
-                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+                buffer = {
+                    "input_ids": [],
+                    "attention_mask": [],
+                    "labels": [],
+                }
                 buffer_len = 0
 
         if example:
diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
index c7bd12c66..6df0b8e18 100644
--- a/src/axolotl/flash_attn.py
+++ b/src/axolotl/flash_attn.py
@@ -5,14 +5,11 @@
 from typing import Optional, Tuple
 
 import torch
-
 import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
 from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
 def forward(
@@ -75,7 +72,11 @@
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         max_s = q_len
         cu_q_lens = torch.arange(
-            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+            0,
+            (bsz + 1) * q_len,
+            step=q_len,
+            dtype=torch.int32,
+            device=qkv.device,
         )
         output = flash_attn_unpadded_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@
         x = rearrange(qkv, "b s three h d -> b s (three h d)")
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(
-            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+            x_unpad,
+            "nnz (three h d) -> nnz three h d",
+            three=3,
+            h=nheads,
         )
         output_unpad = flash_attn_unpadded_qkvpacked_func(
-            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            x_unpad,
+            cu_q_lens,
+            max_s,
+            0.0,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(
             pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                indices,
+                bsz,
+                q_len,
             ),
             "b s (h d) -> b s h d",
             h=nheads,
         )
-    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+    return (
+        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
+        None,
+        None,
+    )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
-    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
     return attention_mask
diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py
index 29a0cb654..15dfb65c4 100644
--- a/src/axolotl/prompt_strategies/alpaca_chat.py
+++ b/src/axolotl/prompt_strategies/alpaca_chat.py
@@ -1,6 +1,7 @@
 """Module containing the AlpacaQAPromptTokenizingStrategy class"""
class""" from typing import Tuple + from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, InstructionPromptTokenizingStrategy, diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py index 5cf89127d..ea67034b3 100644 --- a/src/axolotl/prompt_strategies/creative_acr.py +++ b/src/axolotl/prompt_strategies/creative_acr.py @@ -1,8 +1,9 @@ """Module loading the CreativePromptTokenizingStrategy and similar classes""" -from typing import Tuple, Union, Generator +from typing import Generator, Tuple, Union import yaml + from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy @@ -61,10 +62,14 @@ Answer: {answer} def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( - prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper + prompt["scores"], + default_flow_style=False, + Dumper=yaml.Dumper, ) critiques = yaml.dump( - prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper + prompt["critiques"], + default_flow_style=False, + Dumper=yaml.Dumper, ) evaluation = scores + critiques question = prompt["instruction"] @@ -97,10 +102,14 @@ Evaluation: def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: scores = yaml.dump( - prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper + prompt["scores"], + default_flow_style=False, + Dumper=yaml.Dumper, ) critiques = yaml.dump( - prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper + prompt["critiques"], + default_flow_style=False, + Dumper=yaml.Dumper, ) evaluation = scores + critiques question = prompt["instruction"] @@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase): def load_answer(tokenizer, cfg): return CreativeAnsweringPromptTokenizingStrategy( - CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeAnswerPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) def load_critique(tokenizer, cfg): return CreativeCritiquePromptTokenizingStrategy( - CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeCritiquePrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) def load_revise(tokenizer, cfg): return CreativeRevisePromptTokenizingStrategy( - CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + CreativeRevisePrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, ) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index d1655da32..3acae91b8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): part = part[0] + part[1] if not user_token else part[1] # this is still the user query, we should res = self._tokenize( - part.strip(), add_eos_token=False, strip_bos_token=True + part.strip(), + add_eos_token=False, + strip_bos_token=True, ) if user_token: res["input_ids"] = [user_token, *res["input_ids"]] @@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): part = part[0] + part[1] if not assistant_token else part[1] # this should be the assistent response, should end with an eos token res = self._tokenize( - part.strip(), add_eos_token=True, strip_bos_token=True + part.strip(), + add_eos_token=True, + strip_bos_token=True, ) if assistant_token: - res["input_ids"] = [assistant_token, *res["input_ids"]] + res["input_ids"] = [ + assistant_token, + *res["input_ids"], + ] # 
                 # not masked out from labels
                 labels = copy.deepcopy(res["input_ids"])
             else:
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 97c2e3454..1a2535e19 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -2,8 +2,8 @@
 
 import dataclasses
 import logging
-from enum import auto, Enum
-from typing import List, Optional, Union, Generator
+from enum import Enum, auto
+from typing import Generator, List, Optional, Union
 
 IGNORE_TOKEN_ID = -100
 
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
             label = self.agent_label.format(
-                output=output, reflection=reflection, corrected=corrected
+                output=output,
+                reflection=reflection,
+                corrected=corrected,
             )
             res = f"{res}{label}"
         yield res
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 70e83d6e4..f6852249a 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -4,9 +4,9 @@ import os
 
 from transformers import (
     TrainerCallback,
-    TrainingArguments,
-    TrainerState,
     TrainerControl,
+    TrainerState,
+    TrainingArguments,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
         **kwargs,
     ):
         checkpoint_folder = os.path.join(
-            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
         )
 
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 74812f9a0..c505cccfa 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -5,38 +5,33 @@
 from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import (
-    load_from_disk,
-    load_dataset,
-    Dataset,
-    DatasetDict,
-)
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
-    AlpacaPromptTokenizingStrategy,
-    GPTeacherPromptTokenizingStrategy,
-    OpenAssistantPromptTokenizingStrategy,
-    AlpacaReflectionPTStrategy,
-    ShareGPTPromptTokenizingStrategy,
-    JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy,
     AlpacaMultipleChoicePromptTokenizingStrategy,
+    AlpacaPromptTokenizingStrategy,
+    AlpacaReflectionPTStrategy,
+    CompletionPromptTokenizingStrategy,
+    GPTeacherPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
     SummarizeTLDRPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
+    CompletionPrompter,
     GPTeacherPrompter,
+    JeopardyPrompter,
+    MultipleChoiceConcisePrompter,
+    MultipleChoiceExplainPrompter,
     ReflectAlpacaPrompter,
     ShareGPTPrompter,
-    JeopardyPrompter,
-    CompletionPrompter,
-    MultipleChoiceExplainPrompter,
     SummarizeTLDRPrompter,
-    MultipleChoiceConcisePrompter,
 )
 
@@ -67,7 +62,8 @@
     try:
         if cfg.push_dataset_to_hub:
             dataset = load_dataset(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                use_auth_token=use_auth_token,
             )
dataset["train"] except Exception: # pylint: disable=broad-except @@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets( ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: - load_dataset(d.path, streaming=True, use_auth_token=use_auth_token) + load_dataset( + d.path, + streaming=True, + use_auth_token=use_auth_token, + ) ds_from_hub = True except FileNotFoundError: pass @@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): ds = load_dataset( - "json", data_files=d.path, streaming=False, split=None + "json", + data_files=d.path, + streaming=False, + split=None, ) elif ds_from_hub: if d.data_files: @@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets( ) else: ds = load_dataset( - d.path, streaming=False, use_auth_token=use_auth_token + d.path, + streaming=False, + use_auth_token=use_auth_token, ) else: fp = hf_hub_download( - repo_id=d.path, repo_type="dataset", filename=d.data_files + repo_id=d.path, + repo_type="dataset", + filename=d.data_files, ) ds = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: @@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets( def load_prepare_datasets( - tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path + tokenizer: PreTrainedTokenizerBase, + cfg, + default_dataset_prepared_path, ) -> Tuple[Dataset, Dataset]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len @@ -353,7 +362,8 @@ def load_prepare_datasets( f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" ) dataset.push_to_hub( - f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True + f"{cfg.push_dataset_to_hub}/{ds_hash}", + private=True, ) else: dataset = load_tokenized_prepared_datasets( @@ -365,7 +375,8 @@ def load_prepare_datasets( f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" ) dataset = dataset.shard( - num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx + num_shards=cfg.dataset_shard_num, + index=cfg.dataset_shard_idx, ) dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5cdfaab3c..8ce39b8bc 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -5,23 +5,17 @@ import logging import math import os from pathlib import Path -from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401 +from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import ( # noqa: F401 - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - AutoConfig, - BitsAndBytesConfig, -) +from transformers import AutoModelForCausalLM # noqa: F401 +from transformers import PreTrainedModel # noqa: F401 +from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig try: - from transformers import ( - LlamaForCausalLM, - ) + from transformers import LlamaForCausalLM except ImportError: logging.warning( "This version of transformers does not support Llama. Consider upgrading." 
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from axolotl.utils.dict import DictDefault  # noqa: F401
     from transformers import PreTrainedTokenizer  # noqa: F401
 
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+
 
 def load_tokenizer(
     base_model_config,
@@ -56,7 +51,10 @@
     logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
     logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+    if tokenizer.__class__.__name__ in [
+        "LlamaTokenizer",
+        "LlamaTokenizerFast",
+    ]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import (
-        AdaptionPromptConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
 
 def load_lora(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import (
-        LoraConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import LoraConfig, PeftModel, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])
diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py
index 159dbe15d..1c535eb1b 100644
--- a/src/axolotl/utils/tokenization.py
+++ b/src/axolotl/utils/tokenization.py
@@ -2,6 +2,7 @@
 
 import logging
+
 from termcolor import colored
 
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 45f13e530..4e41d1b61 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
-from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@
         self.lr_scheduler = None
 
     def create_scheduler(
-        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in [
+        "lora",
+        "qlora",
+    ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
diff --git a/tests/test_validation.py b/tests/test_validation.py
index a92666001..15bc07f84 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -4,8 +4,8 @@ import unittest
 
 import pytest
 
-from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):