Apply isort then black
This commit is contained in:
@@ -2,23 +2,20 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from typing import Optional, Union
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
|
|
||||||
|
|
||||||
from axolotl.convert import (
|
from axolotl.convert import (
|
||||||
FileReader,
|
FileReader,
|
||||||
StdoutWriter,
|
|
||||||
FileWriter,
|
FileWriter,
|
||||||
JsonlSerializer,
|
JsonlSerializer,
|
||||||
JsonParser,
|
JsonParser,
|
||||||
JsonToJsonlConverter,
|
JsonToJsonlConverter,
|
||||||
|
StdoutWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
|||||||
@@ -7,20 +7,20 @@ import random
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Dict, Any, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.utils.data import load_prepare_datasets
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from axolotl.utils.data import load_prepare_datasets
|
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -242,7 +242,10 @@ def train(
|
|||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
signal.signal(
|
signal.signal(
|
||||||
signal.SIGINT,
|
signal.SIGINT,
|
||||||
lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
|
lambda signal, frame: (
|
||||||
|
model.save_pretrained(cfg.output_dir),
|
||||||
|
sys.exit(0),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Starting trainer...")
|
logging.info("Starting trainer...")
|
||||||
@@ -255,7 +258,8 @@ def train(
|
|||||||
]
|
]
|
||||||
if len(possible_checkpoints) > 0:
|
if len(possible_checkpoints) > 0:
|
||||||
sorted_paths = sorted(
|
sorted_paths = sorted(
|
||||||
possible_checkpoints, key=lambda path: int(path.split("-")[-1])
|
possible_checkpoints,
|
||||||
|
key=lambda path: int(path.split("-")[-1]),
|
||||||
)
|
)
|
||||||
resume_from_checkpoint = sorted_paths[-1]
|
resume_from_checkpoint = sorted_paths[-1]
|
||||||
logging.info(
|
logging.info(
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -1,6 +1,6 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
from setuptools import setup, find_packages
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
install_requires = []
|
install_requires = []
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from typing import List
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import IterableDataset
|
from datasets import IterableDataset
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
|
|
||||||
|
|
||||||
|
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
||||||
|
|
||||||
# We want this to be a wrapper for an existing dataset that we have loaded
|
# We want this to be a wrapper for an existing dataset that we have loaded
|
||||||
# lets use the concept of middlewares to wrap each dataset, for example
|
# lets use the concept of middlewares to wrap each dataset, for example
|
||||||
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
logging.warning(
|
logging.warning(
|
||||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||||
)
|
)
|
||||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
buffer = {
|
||||||
|
"input_ids": [],
|
||||||
|
"attention_mask": [],
|
||||||
|
"labels": [],
|
||||||
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
|
|
||||||
if example:
|
if example:
|
||||||
|
|||||||
@@ -5,14 +5,11 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||||
from flash_attn.bert_padding import unpad_input, pad_input
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -75,7 +72,11 @@ def forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
max_s = q_len
|
max_s = q_len
|
||||||
cu_q_lens = torch.arange(
|
cu_q_lens = torch.arange(
|
||||||
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
|
0,
|
||||||
|
(bsz + 1) * q_len,
|
||||||
|
step=q_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=qkv.device,
|
||||||
)
|
)
|
||||||
output = flash_attn_unpadded_qkvpacked_func(
|
output = flash_attn_unpadded_qkvpacked_func(
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
@@ -88,25 +89,44 @@ def forward(
|
|||||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
||||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
||||||
x_unpad = rearrange(
|
x_unpad = rearrange(
|
||||||
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
|
x_unpad,
|
||||||
|
"nnz (three h d) -> nnz three h d",
|
||||||
|
three=3,
|
||||||
|
h=nheads,
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||||
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
x_unpad,
|
||||||
|
cu_q_lens,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
)
|
)
|
||||||
output = rearrange(
|
output = rearrange(
|
||||||
pad_input(
|
pad_input(
|
||||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
|
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
||||||
|
indices,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
),
|
),
|
||||||
"b s (h d) -> b s h d",
|
"b s (h d) -> b s h d",
|
||||||
h=nheads,
|
h=nheads,
|
||||||
)
|
)
|
||||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
|
return (
|
||||||
|
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
self,
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# [bsz, seq_len]
|
# [bsz, seq_len]
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
InstructionPromptTokenizingStrategy,
|
InstructionPromptTokenizingStrategy,
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
|
"""Module loading the CreativePromptTokenizingStrategy and similar classes"""
|
||||||
|
|
||||||
from typing import Tuple, Union, Generator
|
from typing import Generator, Tuple, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||||
|
|
||||||
|
|
||||||
@@ -61,10 +62,14 @@ Answer: {answer}
|
|||||||
|
|
||||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
scores = yaml.dump(
|
scores = yaml.dump(
|
||||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
prompt["scores"],
|
||||||
|
default_flow_style=False,
|
||||||
|
Dumper=yaml.Dumper,
|
||||||
)
|
)
|
||||||
critiques = yaml.dump(
|
critiques = yaml.dump(
|
||||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
prompt["critiques"],
|
||||||
|
default_flow_style=False,
|
||||||
|
Dumper=yaml.Dumper,
|
||||||
)
|
)
|
||||||
evaluation = scores + critiques
|
evaluation = scores + critiques
|
||||||
question = prompt["instruction"]
|
question = prompt["instruction"]
|
||||||
@@ -97,10 +102,14 @@ Evaluation:
|
|||||||
|
|
||||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
scores = yaml.dump(
|
scores = yaml.dump(
|
||||||
prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
|
prompt["scores"],
|
||||||
|
default_flow_style=False,
|
||||||
|
Dumper=yaml.Dumper,
|
||||||
)
|
)
|
||||||
critiques = yaml.dump(
|
critiques = yaml.dump(
|
||||||
prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
|
prompt["critiques"],
|
||||||
|
default_flow_style=False,
|
||||||
|
Dumper=yaml.Dumper,
|
||||||
)
|
)
|
||||||
evaluation = scores + critiques
|
evaluation = scores + critiques
|
||||||
question = prompt["instruction"]
|
question = prompt["instruction"]
|
||||||
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
|
|||||||
|
|
||||||
def load_answer(tokenizer, cfg):
|
def load_answer(tokenizer, cfg):
|
||||||
return CreativeAnsweringPromptTokenizingStrategy(
|
return CreativeAnsweringPromptTokenizingStrategy(
|
||||||
CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
CreativeAnswerPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_critique(tokenizer, cfg):
|
def load_critique(tokenizer, cfg):
|
||||||
return CreativeCritiquePromptTokenizingStrategy(
|
return CreativeCritiquePromptTokenizingStrategy(
|
||||||
CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
CreativeCritiquePrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_revise(tokenizer, cfg):
|
def load_revise(tokenizer, cfg):
|
||||||
return CreativeRevisePromptTokenizingStrategy(
|
return CreativeRevisePromptTokenizingStrategy(
|
||||||
CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
CreativeRevisePrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
part = part[0] + part[1] if not user_token else part[1]
|
part = part[0] + part[1] if not user_token else part[1]
|
||||||
# this is still the user query, we should
|
# this is still the user query, we should
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
part.strip(), add_eos_token=False, strip_bos_token=True
|
part.strip(),
|
||||||
|
add_eos_token=False,
|
||||||
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
if user_token:
|
if user_token:
|
||||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||||
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
part = part[0] + part[1] if not assistant_token else part[1]
|
part = part[0] + part[1] if not assistant_token else part[1]
|
||||||
# this should be the assistent response, should end with an eos token
|
# this should be the assistent response, should end with an eos token
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
part.strip(), add_eos_token=True, strip_bos_token=True
|
part.strip(),
|
||||||
|
add_eos_token=True,
|
||||||
|
strip_bos_token=True,
|
||||||
)
|
)
|
||||||
if assistant_token:
|
if assistant_token:
|
||||||
res["input_ids"] = [assistant_token, *res["input_ids"]]
|
res["input_ids"] = [
|
||||||
|
assistant_token,
|
||||||
|
*res["input_ids"],
|
||||||
|
]
|
||||||
# not masked out from labels
|
# not masked out from labels
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from enum import auto, Enum
|
from enum import Enum, auto
|
||||||
from typing import List, Optional, Union, Generator
|
from typing import Generator, List, Optional, Union
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
|
|||||||
res = self.prompt_no_input.format(instruction=instruction)
|
res = self.prompt_no_input.format(instruction=instruction)
|
||||||
if output and reflection and corrected:
|
if output and reflection and corrected:
|
||||||
label = self.agent_label.format(
|
label = self.agent_label.format(
|
||||||
output=output, reflection=reflection, corrected=corrected
|
output=output,
|
||||||
|
reflection=reflection,
|
||||||
|
corrected=corrected,
|
||||||
)
|
)
|
||||||
res = f"{res}{label}"
|
res = f"{res}{label}"
|
||||||
yield res
|
yield res
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ import os
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
|
||||||
TrainerState,
|
|
||||||
TrainerControl,
|
TrainerControl,
|
||||||
|
TrainerState,
|
||||||
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
|
|
||||||
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
checkpoint_folder = os.path.join(
|
checkpoint_folder = os.path.join(
|
||||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
args.output_dir,
|
||||||
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
)
|
)
|
||||||
|
|
||||||
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
||||||
|
|||||||
@@ -5,38 +5,33 @@ from hashlib import md5
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from datasets import (
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||||
load_from_disk,
|
|
||||||
load_dataset,
|
|
||||||
Dataset,
|
|
||||||
DatasetDict,
|
|
||||||
)
|
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
|
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
|
||||||
GPTeacherPromptTokenizingStrategy,
|
|
||||||
OpenAssistantPromptTokenizingStrategy,
|
|
||||||
AlpacaReflectionPTStrategy,
|
|
||||||
ShareGPTPromptTokenizingStrategy,
|
|
||||||
JeopardyPromptTokenizingStrategy,
|
|
||||||
CompletionPromptTokenizingStrategy,
|
|
||||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||||
|
AlpacaPromptTokenizingStrategy,
|
||||||
|
AlpacaReflectionPTStrategy,
|
||||||
|
CompletionPromptTokenizingStrategy,
|
||||||
|
GPTeacherPromptTokenizingStrategy,
|
||||||
|
JeopardyPromptTokenizingStrategy,
|
||||||
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
|
ShareGPTPromptTokenizingStrategy,
|
||||||
SummarizeTLDRPromptTokenizingStrategy,
|
SummarizeTLDRPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import (
|
from axolotl.prompters import (
|
||||||
AlpacaPrompter,
|
AlpacaPrompter,
|
||||||
|
CompletionPrompter,
|
||||||
GPTeacherPrompter,
|
GPTeacherPrompter,
|
||||||
|
JeopardyPrompter,
|
||||||
|
MultipleChoiceConcisePrompter,
|
||||||
|
MultipleChoiceExplainPrompter,
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
ShareGPTPrompter,
|
ShareGPTPrompter,
|
||||||
JeopardyPrompter,
|
|
||||||
CompletionPrompter,
|
|
||||||
MultipleChoiceExplainPrompter,
|
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
MultipleChoiceConcisePrompter,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
|
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
dataset = dataset["train"]
|
dataset = dataset["train"]
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
ds: Union[Dataset, DatasetDict] = None
|
ds: Union[Dataset, DatasetDict] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
|
load_dataset(
|
||||||
|
d.path,
|
||||||
|
streaming=True,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
|
|||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
if Path(d.path).exists():
|
if Path(d.path).exists():
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json", data_files=d.path, streaming=False, split=None
|
"json",
|
||||||
|
data_files=d.path,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
)
|
)
|
||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
if d.data_files:
|
if d.data_files:
|
||||||
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
d.path, streaming=False, use_auth_token=use_auth_token
|
d.path,
|
||||||
|
streaming=False,
|
||||||
|
use_auth_token=use_auth_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
repo_id=d.path, repo_type="dataset", filename=d.data_files
|
repo_id=d.path,
|
||||||
|
repo_type="dataset",
|
||||||
|
filename=d.data_files,
|
||||||
)
|
)
|
||||||
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
ds = load_dataset("json", data_files=fp, streaming=False, split=None)
|
||||||
if not ds:
|
if not ds:
|
||||||
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
cfg,
|
||||||
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset]:
|
) -> Tuple[Dataset, Dataset]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
@@ -353,7 +362,8 @@ def load_prepare_datasets(
|
|||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(
|
dataset.push_to_hub(
|
||||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||||
|
private=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dataset = load_tokenized_prepared_datasets(
|
dataset = load_tokenized_prepared_datasets(
|
||||||
@@ -365,7 +375,8 @@ def load_prepare_datasets(
|
|||||||
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
||||||
)
|
)
|
||||||
dataset = dataset.shard(
|
dataset = dataset.shard(
|
||||||
num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
|
num_shards=cfg.dataset_shard_num,
|
||||||
|
index=cfg.dataset_shard_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||||
|
|||||||
@@ -5,23 +5,17 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Tuple, TYPE_CHECKING # noqa: F401
|
from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import ( # noqa: F401
|
from transformers import AutoModelForCausalLM # noqa: F401
|
||||||
AutoModelForCausalLM,
|
from transformers import PreTrainedModel # noqa: F401
|
||||||
AutoTokenizer,
|
from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
|
||||||
PreTrainedModel,
|
|
||||||
AutoConfig,
|
|
||||||
BitsAndBytesConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import (
|
from transformers import LlamaForCausalLM
|
||||||
LlamaForCausalLM,
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"This version of transformers does not support Llama. Consider upgrading."
|
"This version of transformers does not support Llama. Consider upgrading."
|
||||||
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from peft import PeftConfig # noqa: F401
|
from peft import PeftConfig # noqa: F401
|
||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
|
||||||
from transformers import PreTrainedTokenizer # noqa: F401
|
from transformers import PreTrainedTokenizer # noqa: F401
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
def load_tokenizer(
|
||||||
base_model_config,
|
base_model_config,
|
||||||
@@ -56,7 +51,10 @@ def load_tokenizer(
|
|||||||
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
if tokenizer.__class__.__name__ in [
|
||||||
|
"LlamaTokenizer",
|
||||||
|
"LlamaTokenizerFast",
|
||||||
|
]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
|
|||||||
|
|
||||||
def load_llama_adapter(model, cfg):
|
def load_llama_adapter(model, cfg):
|
||||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
from peft import (
|
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
||||||
AdaptionPromptConfig,
|
|
||||||
get_peft_model,
|
|
||||||
PeftModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
peft_config = AdaptionPromptConfig(
|
peft_config = AdaptionPromptConfig(
|
||||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||||
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
|
|||||||
def load_lora(model, cfg):
|
def load_lora(model, cfg):
|
||||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||||
|
|
||||||
from peft import (
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
LoraConfig,
|
|
||||||
get_peft_model,
|
|
||||||
PeftModel,
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from transformers import EarlyStoppingCallback, Trainer
|
from transformers import EarlyStoppingCallback, Trainer
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
|
|
||||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
|
||||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||||
|
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(Trainer):
|
class OneCycleLRSchedulerTrainer(Trainer):
|
||||||
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
|
|||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
):
|
):
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
)
|
)
|
||||||
callbacks.append(early_stop_cb)
|
callbacks.append(early_stop_cb)
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]: # only save in rank 0
|
if cfg.local_rank == 0 and cfg.adapter in [
|
||||||
|
"lora",
|
||||||
|
"qlora",
|
||||||
|
]: # only save in rank 0
|
||||||
callbacks.append(SavePeftModelCallback)
|
callbacks.append(SavePeftModelCallback)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ import unittest
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.validation import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.validation import validate_config
|
||||||
|
|
||||||
|
|
||||||
class ValidationTest(unittest.TestCase):
|
class ValidationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user