Merge branch 'main' into quadratic-warmup
This commit is contained in:
@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
buffer_len = 0
|
||||
|
||||
if example:
|
||||
# FIXME
|
||||
# just going to drop data points that are too long
|
||||
if len(example["input_ids"]) <= self.seq_length:
|
||||
input_ids = example["input_ids"]
|
||||
|
||||
@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
InstructionPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):
|
||||
|
||||
class AlpacaConcisePrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Prompter extending the system prompt to ask for concise answers
|
||||
Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
|
||||
class AlpacaChatPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
|
||||
"""
|
||||
|
||||
system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
self.prompt_style = PromptStyle.CHAT.value
|
||||
self.match_prompt_style()
|
||||
|
||||
|
||||
class NoSystemPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Null Prompter with no system prompts
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
turn_format = "{instruction} {input} "
|
||||
turn_no_input_format = "{instruction} "
|
||||
|
||||
def __init__(self): # pylint: disable=super-init-not-called
|
||||
pass
|
||||
|
||||
|
||||
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):
|
||||
|
||||
def load_qa(tokenizer, cfg):
|
||||
return AlpacaQAPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):
|
||||
|
||||
def load_camel_ai(tokenizer, cfg):
|
||||
return CamelAIPromptTokenizingStrategy(
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
AlpacaChatPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module loading the AlpacaInstructPromptTokenizingStrategy class"""
|
||||
|
||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_no_prompt(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
UnpromptedPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
120
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
120
src/axolotl/prompt_strategies/alpaca_w_system.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||
"""
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||
|
||||
|
||||
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["instruction"],
|
||||
prompt["input"] if "input" in prompt else "",
|
||||
prompt["output"],
|
||||
prompt["system"],
|
||||
)
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
system,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt_w_system(
|
||||
system,
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class SystemDataPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = system + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = system + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for OpenOrca datasets
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt["question"],
|
||||
"",
|
||||
prompt["response"],
|
||||
prompt["system_prompt"],
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return load_chat(tokenizer, cfg)
|
||||
|
||||
|
||||
def load_instruct(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_chat(tokenizer, cfg):
|
||||
return InstructionWSystemPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
def load_open_orca(tokenizer, cfg):
|
||||
return OpenOrcaPromptTokenizingStrategy(
|
||||
SystemDataPrompter(PromptStyle.INSTRUCT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
def parse_instruction_fields(
|
||||
self, prompt
|
||||
) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
input, # pylint: disable=redefined-builtin
|
||||
response,
|
||||
) = self.parse_instruction_fields(prompt)
|
||||
full_prompt = self._build_full_prompt(instruction, input, response)
|
||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
user_prompt = next(
|
||||
iter(
|
||||
self.prompter.build_prompt(
|
||||
instruction,
|
||||
input,
|
||||
)
|
||||
)
|
||||
tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
)
|
||||
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
||||
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
||||
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
||||
|
||||
return tokenized_full_prompt
|
||||
return tokenized_prompt
|
||||
|
||||
def _build_full_prompt(
|
||||
self, instruction, input, response # pylint: disable=redefined-builtin
|
||||
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
|
||||
result: Dict[str, List[int]],
|
||||
current_len: int,
|
||||
res: Dict[str, List[int]],
|
||||
labels: list[int],
|
||||
labels: List[int],
|
||||
pad_token_id: Union[int, None] = None,
|
||||
) -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
|
||||
@@ -24,6 +24,8 @@ class AlpacaPrompter:
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
|
||||
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
|
||||
@@ -32,23 +34,13 @@ class AlpacaPrompter:
|
||||
|
||||
def match_prompt_style(self):
|
||||
if self.prompt_style == PromptStyle.INSTRUCT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||
self.turn_no_input_format = (
|
||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt
|
||||
+ "### Instruction:\n{instruction}\n\n### Response:\n"
|
||||
)
|
||||
self.response_split = "### Response:"
|
||||
if self.prompt_style == PromptStyle.CHAT.value:
|
||||
self.prompt_input = (
|
||||
self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
)
|
||||
self.prompt_no_input = (
|
||||
self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
@@ -59,16 +51,17 @@ class AlpacaPrompter:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = self.prompt_input.format(instruction=instruction, input=input)
|
||||
res = self.system_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
else:
|
||||
res = self.prompt_no_input.format(instruction=instruction)
|
||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
"""
|
||||
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
||||
"""
|
||||
|
||||
system_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning."
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
system_no_input_prompt = (
|
||||
"Choose the answer that best answers the question. Explain your reasoning.\n"
|
||||
)
|
||||
|
||||
|
||||
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
|
||||
Prompter for multiple choice concise
|
||||
"""
|
||||
|
||||
prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
|
||||
system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
|
||||
Prompter for summarize TLDR
|
||||
"""
|
||||
|
||||
prompt_no_input = (
|
||||
"USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
)
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
|
||||
self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
@@ -128,9 +132,6 @@ class CompletionPrompter:
|
||||
) -> Generator[str, None, None]:
|
||||
yield instruction
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.strip()
|
||||
|
||||
|
||||
class GPTeacherPrompter(AlpacaPrompter):
|
||||
"""
|
||||
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
def get_response(self, output: str) -> str:
|
||||
return output.split(self.response_split)[1].strip()
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
# def match_prompt_style(self):
|
||||
# if self.prompt_style == PromptStyle.chat.value:
|
||||
# self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
|
||||
# self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
|
||||
# self.response_split = "ASSISTANT:"
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
# ignore the system prompt if provided
|
||||
if source[0]["from"] == "system":
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
|
||||
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
||||
kwargs["model"].save_pretrained(peft_model_path)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
class SaveBetterTransformerModelCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods
|
||||
"""Callback to save the BetterTransformer wrapped model"""
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
# Save
|
||||
if (
|
||||
args.save_strategy == IntervalStrategy.STEPS
|
||||
and args.save_steps > 0
|
||||
and state.global_step % args.save_steps == 0
|
||||
):
|
||||
control.should_save = True
|
||||
|
||||
if control.should_save:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
|
||||
model = BetterTransformer.reverse(kwargs["model"])
|
||||
model.save_pretrained(checkpoint_folder)
|
||||
# FIXME - need to cleanup old checkpoints
|
||||
|
||||
# since we're saving here, we don't need the trainer loop to attempt to save too b/c
|
||||
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
|
||||
control.should_save = False
|
||||
return control
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Module containing data utilities"""
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -101,13 +102,26 @@ def load_tokenized_prepared_datasets(
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
if Path(d.path).exists():
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
if d.data_files:
|
||||
ds = load_dataset(
|
||||
@@ -394,8 +408,127 @@ def load_prepare_datasets(
|
||||
index=cfg.dataset_shard_idx,
|
||||
)
|
||||
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
if cfg.val_set_size:
|
||||
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
else:
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(tokenizer, max_tokens, examples):
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
)
|
||||
# Convert to PyTorch tensors
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
new_input_ids = []
|
||||
new_attention_mask = []
|
||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
||||
for i, _ in enumerate(input_ids):
|
||||
input_ids[i] = torch.cat(
|
||||
(
|
||||
input_ids[i],
|
||||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
||||
|
||||
# Concatenate tokens so that their lengths are less than max_tokens
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
for ids, mask in zip(input_ids, attention_mask):
|
||||
if buffer_input_ids.numel() == max_tokens:
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
else:
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
||||
|
||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
||||
|
||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
||||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
||||
buffer_input_ids = torch.cat(
|
||||
(
|
||||
buffer_input_ids,
|
||||
torch.full(
|
||||
(max_tokens - buffer_input_ids.numel(),),
|
||||
tokenizer.pad_token_id,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
buffer_attention_mask = torch.cat(
|
||||
(
|
||||
buffer_attention_mask,
|
||||
torch.full(
|
||||
(max_tokens - buffer_attention_mask.numel(),),
|
||||
0,
|
||||
dtype=torch.long,
|
||||
),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
new_input_ids.append(buffer_input_ids)
|
||||
new_attention_mask.append(buffer_attention_mask)
|
||||
|
||||
ret = {
|
||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
||||
"labels": [seq.tolist() for seq in new_input_ids],
|
||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||
}
|
||||
|
||||
logging.debug(len(ret["input_ids"]))
|
||||
return ret
|
||||
|
||||
|
||||
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
dataset = load_dataset(path, streaming=True, split="train")
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
# TODO dynamically figure out which columns/features to remove
|
||||
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
||||
return dataset
|
||||
|
||||
@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PreTrainedModel # noqa: F401
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers import ( # noqa: F401
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
@@ -32,15 +34,20 @@ def load_tokenizer(
|
||||
tokenizer_type,
|
||||
cfg,
|
||||
):
|
||||
use_fast = True # this is the default
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if tokenizer_type:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
)
|
||||
|
||||
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
@@ -70,7 +77,7 @@ def load_tokenizer(
|
||||
def load_model(
|
||||
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
|
||||
):
|
||||
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
# type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
"""
|
||||
Load a model from a base model and a model type.
|
||||
"""
|
||||
@@ -121,9 +128,9 @@ def load_model(
|
||||
logging.info("patching with xpos rope")
|
||||
replace_llama_rope_with_xpos_rope()
|
||||
|
||||
if cfg.bf16:
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype = torch.bfloat16
|
||||
elif cfg.load_in_8bit or cfg.fp16:
|
||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
@@ -195,7 +202,7 @@ def load_model(
|
||||
else True,
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif cfg.is_llama_derived_model:
|
||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config = LlamaConfig.from_pretrained(base_model_config)
|
||||
@@ -234,7 +241,7 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type:
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
@@ -251,11 +258,16 @@ def load_model(
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
|
||||
if (
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
):
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
):
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
@@ -278,6 +290,7 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=cfg.device_map,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -287,6 +300,16 @@ def load_model(
|
||||
embeddings_len = math.ceil(len(tokenizer) / 32) * 32
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings
|
||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||
):
|
||||
logging.warning(
|
||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if not cfg.gptq and (
|
||||
(cfg.adapter == "lora" and load_in_8bit)
|
||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||
@@ -332,6 +355,9 @@ def load_model(
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
|
||||
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
|
||||
|
||||
logging.info(" ".join(colored_tokens))
|
||||
logging.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
@@ -17,7 +17,10 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||
from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.schedulers import (
|
||||
InterpolatingLogScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
@@ -166,6 +169,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
# TODO search Path("./") for one
|
||||
training_arguments_kwargs["deepspeed"] = "./ds_config.json"
|
||||
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
if cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
|
||||
training_args = AxolotlTrainingArguments(
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
@@ -282,6 +298,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
]: # only save in rank 0
|
||||
callbacks.append(SavePeftModelCallback)
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
@@ -62,7 +64,47 @@ def validate_config(cfg):
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
logging.warning(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
)
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
logging.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".")[0]) < 2:
|
||||
logging.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
logging.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adamw w bf16
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
Reference in New Issue
Block a user