Apply isort then black

This commit is contained in:
NanoCode012
2023-05-29 18:48:58 +09:00
parent 96e8378692
commit 37293dce07
15 changed files with 158 additions and 97 deletions

View File

@@ -2,23 +2,20 @@
import os import os
import sys import sys
from typing import Optional, Union
from pathlib import Path from pathlib import Path
from typing import Optional, Union
import fire import fire
from axolotl.convert import ( from axolotl.convert import (
FileReader, FileReader,
StdoutWriter,
FileWriter, FileWriter,
JsonlSerializer, JsonlSerializer,
JsonParser, JsonParser,
JsonToJsonlConverter, JsonToJsonlConverter,
StdoutWriter,
) )
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")

View File

@@ -7,20 +7,20 @@ import random
import signal import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Any, Union from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch import torch
import yaml import yaml
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
if cfg.local_rank == 0: if cfg.local_rank == 0:
signal.signal( signal.signal(
signal.SIGINT, signal.SIGINT,
lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)), lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
) )
logging.info("Starting trainer...") logging.info("Starting trainer...")
@@ -255,7 +258,8 @@ def train(
] ]
if len(possible_checkpoints) > 0: if len(possible_checkpoints) > 0:
sorted_paths = sorted( sorted_paths = sorted(
possible_checkpoints, key=lambda path: int(path.split("-")[-1]) possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
) )
resume_from_checkpoint = sorted_paths[-1] resume_from_checkpoint = sorted_paths[-1]
logging.info( logging.info(

View File

@@ -1,6 +1,6 @@
"""setup.py for axolotl""" """setup.py for axolotl"""
from setuptools import setup, find_packages from setuptools import find_packages, setup
install_requires = [] install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file: with open("./requirements.txt", encoding="utf-8") as requirements_file:

View File

@@ -5,8 +5,8 @@ from typing import List
import torch import torch
from datasets import IterableDataset from datasets import IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded # We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example # lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
logging.warning( logging.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
) )
buffer = {"input_ids": [], "attention_mask": [], "labels": []} buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
buffer_len = 0 buffer_len = 0
if example: if example:

View File

@@ -5,14 +5,11 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import transformers import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward( def forward(
@@ -75,7 +72,11 @@ def forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len max_s = q_len
cu_q_lens = torch.arange( cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
) )
output = flash_attn_unpadded_qkvpacked_func( output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@ def forward(
x = rearrange(qkv, "b s three h d -> b s (three h d)") x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange( x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
) )
output_unpad = flash_attn_unpadded_qkvpacked_func( output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
) )
output = rearrange( output = rearrange(
pad_input( pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
), ),
"b s (h d) -> b s h d", "b s (h d) -> b s h d",
h=nheads, h=nheads,
) )
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
return attention_mask return attention_mask

View File

@@ -1,6 +1,7 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class""" """Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple from typing import Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy, InstructionPromptTokenizingStrategy,

View File

@@ -1,8 +1,9 @@
"""Module loading the CreativePromptTokenizingStrategy and similar classes""" """Module loading the CreativePromptTokenizingStrategy and similar classes"""
from typing import Tuple, Union, Generator from typing import Generator, Tuple, Union
import yaml import yaml
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
@@ -61,10 +62,14 @@ Answer: {answer}
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump( scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
) )
critiques = yaml.dump( critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
) )
evaluation = scores + critiques evaluation = scores + critiques
question = prompt["instruction"] question = prompt["instruction"]
@@ -97,10 +102,14 @@ Evaluation:
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
scores = yaml.dump( scores = yaml.dump(
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper prompt["scores"],
default_flow_style=False,
Dumper=yaml.Dumper,
) )
critiques = yaml.dump( critiques = yaml.dump(
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper prompt["critiques"],
default_flow_style=False,
Dumper=yaml.Dumper,
) )
evaluation = scores + critiques evaluation = scores + critiques
question = prompt["instruction"] question = prompt["instruction"]
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
def load_answer(tokenizer, cfg): def load_answer(tokenizer, cfg):
return CreativeAnsweringPromptTokenizingStrategy( return CreativeAnsweringPromptTokenizingStrategy(
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len CreativeAnswerPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
def load_critique(tokenizer, cfg): def load_critique(tokenizer, cfg):
return CreativeCritiquePromptTokenizingStrategy( return CreativeCritiquePromptTokenizingStrategy(
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len CreativeCritiquePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )
def load_revise(tokenizer, cfg): def load_revise(tokenizer, cfg):
return CreativeRevisePromptTokenizingStrategy( return CreativeRevisePromptTokenizingStrategy(
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len CreativeRevisePrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
) )

View File

@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not user_token else part[1] part = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should # this is still the user query, we should
res = self._tokenize( res = self._tokenize(
part.strip(), add_eos_token=False, strip_bos_token=True part.strip(),
add_eos_token=False,
strip_bos_token=True,
) )
if user_token: if user_token:
res["input_ids"] = [user_token, *res["input_ids"]] res["input_ids"] = [user_token, *res["input_ids"]]
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
part = part[0] + part[1] if not assistant_token else part[1] part = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistent response, should end with an eos token # this should be the assistent response, should end with an eos token
res = self._tokenize( res = self._tokenize(
part.strip(), add_eos_token=True, strip_bos_token=True part.strip(),
add_eos_token=True,
strip_bos_token=True,
) )
if assistant_token: if assistant_token:
res["input_ids"] = [assistant_token, *res["input_ids"]] res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels # not masked out from labels
labels = copy.deepcopy(res["input_ids"]) labels = copy.deepcopy(res["input_ids"])
else: else:

View File

@@ -2,8 +2,8 @@
import dataclasses import dataclasses
import logging import logging
from enum import auto, Enum from enum import Enum, auto
from typing import List, Optional, Union, Generator from typing import Generator, List, Optional, Union
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
res = self.prompt_no_input.format(instruction=instruction) res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected: if output and reflection and corrected:
label = self.agent_label.format( label = self.agent_label.format(
output=output, reflection=reflection, corrected=corrected output=output,
reflection=reflection,
corrected=corrected,
) )
res = f"{res}{label}" res = f"{res}{label}"
yield res yield res

View File

@@ -4,9 +4,9 @@ import os
from transformers import ( from transformers import (
TrainerCallback, TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl, TrainerControl,
TrainerState,
TrainingArguments,
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
**kwargs, **kwargs,
): ):
checkpoint_folder = os.path.join( checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
) )
peft_model_path = os.path.join(checkpoint_folder, "adapter_model") peft_model_path = os.path.join(checkpoint_folder, "adapter_model")

View File

@@ -5,38 +5,33 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union
from datasets import ( from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
load_from_disk,
load_dataset,
Dataset,
DatasetDict,
)
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
ShareGPTPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
CompletionPromptTokenizingStrategy,
AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
ShareGPTPromptTokenizingStrategy,
SummarizeTLDRPromptTokenizingStrategy, SummarizeTLDRPromptTokenizingStrategy,
) )
from axolotl.prompters import ( from axolotl.prompters import (
AlpacaPrompter, AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter, GPTeacherPrompter,
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
ShareGPTPrompter, ShareGPTPrompter,
JeopardyPrompter,
CompletionPrompter,
MultipleChoiceExplainPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
MultipleChoiceConcisePrompter,
) )
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
dataset = load_dataset( dataset = load_dataset(
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token f"{cfg.push_dataset_to_hub}/{ds_hash}",
use_auth_token=use_auth_token,
) )
dataset = dataset["train"] dataset = dataset["train"]
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token) load_dataset(
d.path,
streaming=True,
use_auth_token=use_auth_token,
)
ds_from_hub = True ds_from_hub = True
except FileNotFoundError: except FileNotFoundError:
pass pass
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
if Path(d.path).exists(): if Path(d.path).exists():
ds = load_dataset( ds = load_dataset(
"json", data_files=d.path, streaming=False, split=None "json",
data_files=d.path,
streaming=False,
split=None,
) )
elif ds_from_hub: elif ds_from_hub:
if d.data_files: if d.data_files:
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
) )
else: else:
ds = load_dataset( ds = load_dataset(
d.path, streaming=False, use_auth_token=use_auth_token d.path,
streaming=False,
use_auth_token=use_auth_token,
) )
else: else:
fp = hf_hub_download( fp = hf_hub_download(
repo_id=d.path, repo_type="dataset", filename=d.data_files repo_id=d.path,
repo_type="dataset",
filename=d.data_files,
) )
ds = load_dataset("json", data_files=fp, streaming=False, split=None) ds = load_dataset("json", data_files=fp, streaming=False, split=None)
if not ds: if not ds:
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset]: ) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
@@ -353,7 +362,8 @@ def load_prepare_datasets(
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
) )
dataset.push_to_hub( dataset.push_to_hub(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}",
private=True,
) )
else: else:
dataset = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
@@ -365,7 +375,8 @@ def load_prepare_datasets(
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards" f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
) )
dataset = dataset.shard( dataset = dataset.shard(
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx num_shards=cfg.dataset_shard_num,
index=cfg.dataset_shard_idx,
) )
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)

View File

@@ -5,23 +5,17 @@ import logging
import math import math
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401 from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
from transformers import ( # noqa: F401 from transformers import AutoModelForCausalLM # noqa: F401
AutoModelForCausalLM, from transformers import PreTrainedModel # noqa: F401
AutoTokenizer, from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
PreTrainedModel,
AutoConfig,
BitsAndBytesConfig,
)
try: try:
from transformers import ( from transformers import LlamaForCausalLM
LlamaForCausalLM,
)
except ImportError: except ImportError:
logging.warning( logging.warning(
"This version of transformers does not support Llama. Consider upgrading." "This version of transformers does not support Llama. Consider upgrading."
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING: if TYPE_CHECKING:
from peft import PeftConfig # noqa: F401 from peft import PeftConfig # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
from transformers import PreTrainedTokenizer # noqa: F401 from transformers import PreTrainedTokenizer # noqa: F401
from axolotl.utils.dict import DictDefault # noqa: F401
def load_tokenizer( def load_tokenizer(
base_model_config, base_model_config,
@@ -56,7 +51,10 @@ def load_tokenizer(
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: if tokenizer.__class__.__name__ in [
"LlamaTokenizer",
"LlamaTokenizerFast",
]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
def load_llama_adapter(model, cfg): def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import ( from peft import AdaptionPromptConfig, PeftModel, get_peft_model
AdaptionPromptConfig,
get_peft_model,
PeftModel,
)
peft_config = AdaptionPromptConfig( peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L) adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
def load_lora(model, cfg): def load_lora(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import ( from peft import LoraConfig, PeftModel, get_peft_model
LoraConfig,
get_peft_model,
PeftModel,
)
lora_target_modules = list(cfg.lora_target_modules or []) lora_target_modules = list(cfg.lora_target_modules or [])

View File

@@ -2,6 +2,7 @@
import logging import logging
from termcolor import colored from termcolor import colored

View File

@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.schedulers import InterpolatingLogScheduler
class OneCycleLRSchedulerTrainer(Trainer): class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
self.lr_scheduler = None self.lr_scheduler = None
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
): ):
optimizer = self.optimizer if optimizer is None else optimizer optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps) num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
) )
callbacks.append(early_stop_cb) callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0 if cfg.local_rank == 0 and cfg.adapter in [
"lora",
"qlora",
]: # only save in rank 0
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
data_collator_kwargs = { data_collator_kwargs = {

View File

@@ -4,8 +4,8 @@ import unittest
import pytest import pytest
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):