Apply isort then black
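
The diff below is purely mechanical: isort regroups and alphabetizes the imports, then black reflows any statement that exceeds its line limit. As a rough sketch (not part of this commit), the same two-pass transformation can be reproduced with the two tools' programmatic APIs; the sample input is hypothetical, and isort >= 5 plus black with default settings are assumed:

    import black
    import isort

    # Hypothetical, unformatted input resembling the files touched below.
    MESSY = (
        "import sys\n"
        "import os\n"
        "from typing import Optional, List, Dict, Any, Union\n"
    )

    # Pass 1: isort groups imports by section (stdlib, third-party,
    # first-party) and alphabetizes both modules and imported names.
    sorted_src = isort.code(MESSY)

    # Pass 2: black normalizes layout, e.g. exploding calls that overflow
    # the default 88-column limit onto one argument per line.
    formatted = black.format_str(sorted_src, mode=black.Mode())
    print(formatted)

The order matters: isort rewrites each import block wholesale, and running black afterwards lets it apply its own line-breaking rules on top of the sorted result.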
@@ -2,23 +2,20 @@
 
 import os
 import sys
-
-from typing import Optional, Union
 from pathlib import Path
+from typing import Optional, Union
 
 import fire
 
-
 from axolotl.convert import (
     FileReader,
-    StdoutWriter,
     FileWriter,
     JsonlSerializer,
     JsonParser,
     JsonToJsonlConverter,
+    StdoutWriter,
 )
 
-
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -7,20 +7,20 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional, List, Dict, Any, Union
+from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
 
+from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.validation import validate_config
-from axolotl.utils.dict import DictDefault
-
-from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
     if cfg.local_rank == 0:
         signal.signal(
             signal.SIGINT,
-            lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
+            lambda signal, frame: (
+                model.save_pretrained(cfg.output_dir),
+                sys.exit(0),
+            ),
         )
 
     logging.info("Starting trainer...")
@@ -255,7 +258,8 @@ def train(
         ]
         if len(possible_checkpoints) > 0:
             sorted_paths = sorted(
-                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
             )
             resume_from_checkpoint = sorted_paths[-1]
             logging.info(
setup.py
@@ -1,6 +1,6 @@
 """setup.py for axolotl"""
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = []
 with open("./requirements.txt", encoding="utf-8") as requirements_file:
@@ -5,8 +5,8 @@ from typing import List
 
 import torch
 from datasets import IterableDataset
-from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
+
+from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
                     logging.warning(
                         f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                     )
-                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+                buffer = {
+                    "input_ids": [],
+                    "attention_mask": [],
+                    "labels": [],
+                }
                 buffer_len = 0
 
             if example:
@@ -5,14 +5,11 @@
 from typing import Optional, Tuple
 
 import torch
-
 import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
 from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
 def forward(
@@ -75,7 +72,11 @@ def forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         max_s = q_len
         cu_q_lens = torch.arange(
-            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+            0,
+            (bsz + 1) * q_len,
+            step=q_len,
+            dtype=torch.int32,
+            device=qkv.device,
         )
         output = flash_attn_unpadded_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@ def forward(
         x = rearrange(qkv, "b s three h d -> b s (three h d)")
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(
-            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+            x_unpad,
+            "nnz (three h d) -> nnz three h d",
+            three=3,
+            h=nheads,
         )
         output_unpad = flash_attn_unpadded_qkvpacked_func(
-            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            x_unpad,
+            cu_q_lens,
+            max_s,
+            0.0,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(
             pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                indices,
+                bsz,
+                q_len,
             ),
             "b s (h d) -> b s h d",
             h=nheads,
         )
-    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+    return (
+        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
+        None,
+        None,
+    )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
-    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
     return attention_mask
@@ -1,6 +1,7 @@
 """Module containing the AlpacaQAPromptTokenizingStrategy class"""
+
 from typing import Tuple
 
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
@@ -1,8 +1,9 @@
 """Module loading the CreativePromptTokenizingStrategy and similar classes"""
+
-from typing import Tuple, Union, Generator
+from typing import Generator, Tuple, Union
 
 import yaml
 
 from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
 
 
@@ -61,10 +62,14 @@ Answer: {answer}
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -97,10 +102,14 @@ Evaluation:
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
 
 def load_answer(tokenizer, cfg):
     return CreativeAnsweringPromptTokenizingStrategy(
-        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeAnswerPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
 def load_critique(tokenizer, cfg):
     return CreativeCritiquePromptTokenizingStrategy(
-        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeCritiquePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
 def load_revise(tokenizer, cfg):
     return CreativeRevisePromptTokenizingStrategy(
-        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeRevisePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     part = part[0] + part[1] if not user_token else part[1]
                     # this is still the user query, we should
                     res = self._tokenize(
-                        part.strip(), add_eos_token=False, strip_bos_token=True
+                        part.strip(),
+                        add_eos_token=False,
+                        strip_bos_token=True,
                     )
                     if user_token:
                         res["input_ids"] = [user_token, *res["input_ids"]]
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     part = part[0] + part[1] if not assistant_token else part[1]
                     # this should be the assistent response, should end with an eos token
                     res = self._tokenize(
-                        part.strip(), add_eos_token=True, strip_bos_token=True
+                        part.strip(),
+                        add_eos_token=True,
+                        strip_bos_token=True,
                     )
                     if assistant_token:
-                        res["input_ids"] = [assistant_token, *res["input_ids"]]
+                        res["input_ids"] = [
+                            assistant_token,
+                            *res["input_ids"],
+                        ]
                     # not masked out from labels
                     labels = copy.deepcopy(res["input_ids"])
                 else:
@@ -2,8 +2,8 @@
 
 import dataclasses
 import logging
-from enum import auto, Enum
-from typing import List, Optional, Union, Generator
+from enum import Enum, auto
+from typing import Generator, List, Optional, Union
 
 IGNORE_TOKEN_ID = -100
 
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
             label = self.agent_label.format(
-                output=output, reflection=reflection, corrected=corrected
+                output=output,
+                reflection=reflection,
+                corrected=corrected,
             )
             res = f"{res}{label}"
         yield res
@@ -4,9 +4,9 @@ import os
 
 from transformers import (
     TrainerCallback,
-    TrainingArguments,
-    TrainerState,
     TrainerControl,
+    TrainerState,
+    TrainingArguments,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
         **kwargs,
     ):
         checkpoint_folder = os.path.join(
-            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
         )
 
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
@@ -5,38 +5,33 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import (
-    load_from_disk,
-    load_dataset,
-    Dataset,
-    DatasetDict,
-)
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
-    AlpacaPromptTokenizingStrategy,
-    GPTeacherPromptTokenizingStrategy,
-    OpenAssistantPromptTokenizingStrategy,
-    AlpacaReflectionPTStrategy,
-    ShareGPTPromptTokenizingStrategy,
-    JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy,
     AlpacaMultipleChoicePromptTokenizingStrategy,
+    AlpacaPromptTokenizingStrategy,
+    AlpacaReflectionPTStrategy,
+    CompletionPromptTokenizingStrategy,
+    GPTeacherPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
     SummarizeTLDRPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
+    CompletionPrompter,
     GPTeacherPrompter,
+    JeopardyPrompter,
+    MultipleChoiceConcisePrompter,
+    MultipleChoiceExplainPrompter,
     ReflectAlpacaPrompter,
     ShareGPTPrompter,
-    JeopardyPrompter,
-    CompletionPrompter,
-    MultipleChoiceExplainPrompter,
     SummarizeTLDRPrompter,
-    MultipleChoiceConcisePrompter,
 )
 
 
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
     try:
         if cfg.push_dataset_to_hub:
             dataset = load_dataset(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                use_auth_token=use_auth_token,
             )
             dataset = dataset["train"]
     except Exception:  # pylint: disable=broad-except
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
|
||||
load_dataset(
|
||||
d.path,
|
||||
streaming=True,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
             ds = load_dataset(
-                "json", data_files=d.path, streaming=False, split=None
+                "json",
+                data_files=d.path,
+                streaming=False,
+                split=None,
             )
         elif ds_from_hub:
             if d.data_files:
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
                 )
             else:
                 ds = load_dataset(
-                    d.path, streaming=False, use_auth_token=use_auth_token
+                    d.path,
+                    streaming=False,
+                    use_auth_token=use_auth_token,
                 )
         else:
             fp = hf_hub_download(
-                repo_id=d.path, repo_type="dataset", filename=d.data_files
+                repo_id=d.path,
+                repo_type="dataset",
+                filename=d.data_files,
             )
             ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
 
 
 def load_prepare_datasets(
-    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    default_dataset_prepared_path,
 ) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
@@ -353,7 +362,8 @@ def load_prepare_datasets(
                 f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
             )
             dataset.push_to_hub(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                private=True,
             )
         else:
             dataset = load_tokenized_prepared_datasets(
@@ -365,7 +375,8 @@ def load_prepare_datasets(
                 f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
             )
             dataset = dataset.shard(
-                num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+                num_shards=cfg.dataset_shard_num,
+                index=cfg.dataset_shard_idx,
             )
 
     dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
@@ -5,23 +5,17 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING  # noqa: F401
+from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (  # noqa: F401
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    PreTrainedModel,
-    AutoConfig,
-    BitsAndBytesConfig,
-)
+from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import PreTrainedModel  # noqa: F401
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
 
 try:
-    from transformers import (
-        LlamaForCausalLM,
-    )
+    from transformers import LlamaForCausalLM
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from axolotl.utils.dict import DictDefault  # noqa: F401
     from transformers import PreTrainedTokenizer  # noqa: F401
 
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+
 
 def load_tokenizer(
     base_model_config,
@@ -56,7 +51,10 @@ def load_tokenizer(
     logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
     logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+    if tokenizer.__class__.__name__ in [
+        "LlamaTokenizer",
+        "LlamaTokenizerFast",
+    ]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import (
-        AdaptionPromptConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
 def load_lora(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
-    from peft import (
-        LoraConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import LoraConfig, PeftModel, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])
 
@@ -2,6 +2,7 @@
+
 import logging
 
 from termcolor import colored
 
 
@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
-from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
         self.lr_scheduler = None
 
     def create_scheduler(
-        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in [
+        "lora",
+        "qlora",
+    ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
@@ -4,8 +4,8 @@ import unittest
 
 import pytest
 
-from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):
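
A note on the recurring shape of these hunks: almost every change is either an import reshuffle (isort) or a call exploded onto one argument per line (black). Illustratively, with hypothetical names not taken from this commit, black keeps a call on one line while it fits its default 88-column limit, and otherwise breaks it apart:

    # Fits within the line limit: black leaves it on one line.
    result = do_something(alpha, beta, gamma)

    # Too long for one line (or written with a trailing comma): black
    # explodes it to one argument per line and keeps the trailing comma.
    result = do_something(
        alpha,
        beta,
        gamma,
    )

The trailing comma is deliberate ("magic trailing comma"): once it is present, black preserves the exploded form even if the call would later fit on one line again, which keeps future diffs small.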