Merge pull request #26 from OpenAccess-AI-Collective/mpt-triton

Mpt triton
This commit is contained in:
Wing Lian
2023-05-10 16:02:05 -04:00
committed by GitHub
12 changed files with 143 additions and 75 deletions

View File

@@ -191,7 +191,9 @@ def train(
     if cfg.debug:
         logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
+            train_dataset.select(
+                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]
+            ),
             tokenizer,
         )
@@ -218,17 +220,20 @@ def train(
     logging.info("Starting trainer...")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")]
+        possible_checkpoints = [
+            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
+        ]
         if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(possible_checkpoints, key=lambda path: int(path.split('-')[-1]))
+            sorted_paths = sorted(
+                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+            )
             resume_from_checkpoint = sorted_paths[-1]
-            logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
+            logging.info(
+                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
+            )
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-    logging.info(
-        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-    )
+    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     trainer.save_pretrained(cfg.output_dir)

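For reference, the auto-resume logic reflowed above picks the newest checkpoint by parsing the trailing step number out of each `checkpoint-*` directory name. A minimal standalone sketch of the same selection rule (the helper name and example paths are illustrative, not from the diff):

```python
from pathlib import Path

def latest_checkpoint(output_dir: str):
    """Return the checkpoint path with the highest step number, or None."""
    checkpoints = [str(cp) for cp in Path(output_dir).glob("checkpoint-*")]
    if not checkpoints:
        return None
    # Parse numerically so "checkpoint-1500" beats "checkpoint-200"
    return max(checkpoints, key=lambda path: int(path.split("-")[-1]))

# latest_checkpoint("./out")  ->  "out/checkpoint-1500" (hypothetical layout)
```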
View File

@@ -10,22 +10,22 @@ with open("./requirements.txt", "r") as requirements_file:
         install_requires.append(r)

 setup(
-    name='axolotl',
-    version='0.1',
+    name="axolotl",
+    version="0.1",
     description="You know you're going to axolotl questions",
-    package_dir={'': 'src'},
+    package_dir={"": "src"},
     packages=find_packages(),
     install_requires=install_requires,
     extras_require={
-        'int4': [
+        "int4": [
             "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
-        'int4_triton': [
+        "int4_triton": [
             "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
-        'extras': [
-            'flash-attn',
-            'deepspeed',
-        ]
+        "extras": [
+            "flash-attn",
+            "deepspeed",
+        ],
     },
 )

View File

@@ -31,6 +31,7 @@ class TokenizedPromptDataset(IterableDataset):
         except InvalidDataException:
             pass

+
 # TODO this isn't the best since it can't interleave datasets
 class ConstantLengthDataset(IterableDataset):
     """
@@ -40,6 +41,7 @@ class ConstantLengthDataset(IterableDataset):
         dataset (dataset.Dataset): Dataset with text files.
         seq_length (int): Length of token sequences to return.
     """
+
     def __init__(
         self,
         tokenizer,
@@ -93,14 +95,19 @@ class ConstantLengthDataset(IterableDataset):
                        : self.seq_length
                    ]
                    labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                    if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                    if (
+                        labels.size() == input_ids.size()
+                        and attention_mask.size() == input_ids.size()
+                    ):
                        yield {
                            "input_ids": input_ids,
                            "labels": labels,
                            "attention_mask": attention_mask,
                        }
                    else:
-                        logging.warning("dropping batch due to tensor size mismatch")
+                        logging.warning(
+                            "dropping batch due to tensor size mismatch"
+                        )
                    buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                    buffer_len = 0
@@ -116,11 +123,15 @@ class ConstantLengthDataset(IterableDataset):
                attention_mask.append(1)
                labels.append(self.concat_token_id)

-                input_ids_with_concat = torch.tensor(input_ids, dtype=self.tokens_dtype)
+                input_ids_with_concat = torch.tensor(
+                    input_ids, dtype=self.tokens_dtype
+                )
                attention_mask_with_concat = torch.tensor(
                    attention_mask, dtype=self.tokens_dtype
                )
-                labels_with_concat = torch.tensor(labels, dtype=self.tokens_dtype)
+                labels_with_concat = torch.tensor(
+                    labels, dtype=self.tokens_dtype
+                )
                buffer["input_ids"].append(input_ids_with_concat)
                buffer["attention_mask"].append(attention_mask_with_concat)

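For context, `ConstantLengthDataset` packs several tokenized examples into fixed-length sequences, joining them with a concat token and dropping any chunk whose tensors disagree in size. A simplified sketch of the packing idea only (not the class itself; the streaming buffer and the size check above are omitted):

```python
import torch

def pack_examples(examples, concat_token_id, seq_length):
    """Concatenate tokenized examples, then cut one fixed-length chunk."""
    buffer = []
    for ids in examples:  # each `ids` is a list of token ids
        buffer.extend(ids + [concat_token_id])
    input_ids = torch.tensor(buffer[:seq_length], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),
        "labels": input_ids.clone(),
    }

# pack_examples([[1, 2, 3], [4, 5]], concat_token_id=0, seq_length=4)
# -> input_ids == tensor([1, 2, 3, 0])
```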
View File

@@ -126,10 +126,8 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):


 class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
-    def parse_instruction_fields(self, prompt) -> (str):
-        return (
-            prompt["text"]
-        )
+    def parse_instruction_fields(self, prompt) -> str:
+        return prompt["text"]

     def tokenize_prompt(self, prompt):
         instruction = self.parse_instruction_fields(prompt)
@@ -139,9 +137,7 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
         return tokenized_full_prompt

     def _build_full_prompt(self, instruction):
-        return self.prompter.build_prompt(
-            instruction
-        )
+        return self.prompter.build_prompt(instruction)


 class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
@@ -149,8 +145,16 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
         raise NotImplementedError

     def tokenize_prompt(self, prompt):
-        instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
-        full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
+        (
+            instruction,
+            input,
+            output,
+            reflection,
+            corrected,
+        ) = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(
+            instruction, input, output, reflection, corrected
+        )
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(

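The `train_on_inputs` branch visible at the end of the hunk masks the user portion of the prompt so the loss is computed only on the response. A hedged sketch of that masking pattern (`IGNORE_TOKEN_ID = -100` is the value `transformers` loss functions skip; the helper itself is illustrative, not from the repo):

```python
IGNORE_TOKEN_ID = -100  # label value excluded from the loss

def mask_prompt_tokens(full_ids, prompt_len):
    """Keep labels only for response tokens; prompt tokens get no gradient."""
    labels = list(full_ids)
    labels[:prompt_len] = [IGNORE_TOKEN_ID] * prompt_len
    return labels

# mask_prompt_tokens([10, 11, 12, 13], prompt_len=2) -> [-100, -100, 12, 13]
```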
View File

@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):


 class CompletionPrompter(AlpacaPrompter):
-    def build_prompt(
-        self,
-        instruction: str
-    ) -> str:
+    def build_prompt(self, instruction: str) -> str:
         return instruction

     def get_response(self, output: str) -> str:
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
         else:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
-            label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
+            label = self.agent_label.format(
+                output=output, reflection=reflection, corrected=corrected
+            )
             res = f"{res}{label}"
         return res

@@ -200,9 +199,13 @@ class ShareGPTPrompter:
             if len(parts) != 2:
                 break
             parts[0] += sep
-            round_len = len(tokenizer(rou)["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+            round_len = (
+                len(tokenizer(rou)["input_ids"]) - 1
+            )  # -1 ignores the bos_token generated for this
             # we have to strip the initial part, any dangling whitespace creates an additional ghost token
-            instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+            instruction_len = (
+                len(tokenizer(parts[0].strip())["input_ids"]) - 1
+            )  # -1 ignores the bos_token generated for this
             target[cur_len : cur_len + instruction_len] = [
                 IGNORE_TOKEN_ID
             ] * instruction_len
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
                 break

         # Fix: Truncate the target to have the same length as input_ids
-        target = target[:len(tokenized_result["input_ids"])]
+        target = target[: len(tokenized_result["input_ids"])]
         # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)

         attention_mask = [

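The `- 1` corrections wrapped above exist because a LLaMA-style tokenizer prepends a BOS token on every call, so naive length arithmetic over-counts each conversation round by one. A small sketch of the measurement the inline comments describe (the tokenizer behavior is an assumption here):

```python
def token_len_without_bos(tokenizer, text: str) -> int:
    """Token count of `text`, excluding the BOS the tokenizer prepends."""
    return len(tokenizer(text)["input_ids"]) - 1

# With a LLaMA-style tokenizer, tokenizer("hi")["input_ids"] could be
# [bos_id, hi_id]; this helper reports 1, the count that matters for masking.
```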
View File

@@ -1,8 +1,15 @@
 import os

-from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
+from transformers import (
+    Seq2SeqTrainer,
+    TrainerCallback,
+    TrainingArguments,
+    TrainerState,
+    TrainerControl,
+)
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


 class SavePeftModelCallback(TrainerCallback):
     def on_save(
         self,
@@ -11,7 +18,9 @@ class SavePeftModelCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
+        checkpoint_folder = os.path.join(
+            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+        )

         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
         kwargs["model"].save_pretrained(peft_model_path)

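For orientation, `SavePeftModelCallback` hooks the trainer's save event and writes only the PEFT adapter weights beside each full checkpoint. A sketch of the same pattern as a self-contained callback (the class name here is illustrative; `on_save` and its arguments are the real `TrainerCallback` hook):

```python
import os
from transformers import TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

class AdapterOnlySaveCallback(TrainerCallback):
    """Persist just the adapter next to each checkpoint directory."""

    def on_save(self, args, state, control, **kwargs):
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        return control

# Registered the same way as in the trainer setup later in this diff:
# trainer = transformers.Trainer(..., callbacks=[AdapterOnlySaveCallback])
```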
View File

@@ -2,7 +2,13 @@ import logging
 from hashlib import md5
 from pathlib import Path

-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
+from datasets import (
+    load_from_disk,
+    load_dataset,
+    IterableDataset,
+    Dataset,
+    concatenate_datasets,
+)
 from huggingface_hub import hf_hub_download

 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -75,7 +81,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             else:
                 ds = load_dataset(d.path, streaming=True)
         else:
-            fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+            fp = hf_hub_download(
+                repo_id=d.path, repo_type="dataset", filename=d.data_files
+            )
             ds = load_dataset("json", data_files=fp, streaming=True, split=None)
         if not ds:
             raise Exception("unhandled dataset load")
@@ -140,7 +148,9 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             samples = samples + [i for i in d]
         dataset = Dataset.from_list(samples).shuffle(seed=42)
         if cfg.local_rank == 0:
-            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
+            logging.info(
+                f"Saving merged prepared dataset to disk... {prepared_ds_path}"
+            )
             dataset.save_to_disk(prepared_ds_path)

     if cfg.max_packed_sequence_len is not None:
@@ -153,12 +163,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         dataset = Dataset.from_list([_ for _ in constant_len_dataset])

     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
-        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
-        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
+        logging.info(
+            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
+        )
+        dataset = dataset.shard(
+            num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+        )

-    dataset = dataset.train_test_split(
-        test_size=cfg.val_set_size, shuffle=False
-    )
+    dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]

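`Dataset.shard` above keeps a deterministic fraction of the prepared data, which is handy for smoke-testing a config on a slice of the corpus. A minimal sketch with a toy dataset (`contiguous=False` selects every `num_shards`-th row, matching the non-contiguous default of older `datasets` releases):

```python
from datasets import Dataset

dataset = Dataset.from_list([{"n": i} for i in range(10)])
shard = dataset.shard(num_shards=4, index=0, contiguous=False)
print(shard["n"])  # [0, 4, 8]
```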
View File

@@ -9,14 +9,18 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
+    AutoConfig,
 )

 try:
     from transformers import (
         LlamaForCausalLM,
         LlamaTokenizer,
     )
 except:
-    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
+    logging.warning(
+        "This version of transformers does not support Llama. Consider upgrading."
+    )

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -40,7 +44,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
-    is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
+    is_llama_derived_model = "llama" in base_model or (
+        cfg.model_type and "llama" in cfg.model_type.lower()
+    )

     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
@@ -49,11 +55,16 @@ def load_model(
             logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
     elif is_llama_derived_model and cfg.xformers_attention:
-        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention
+        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
+            hijack_llama_attention,
+        )
         logging.info("patching with xformers attention")
         hijack_llama_attention()

-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+    torch_dtype = (
+        torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
+    )
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -74,8 +85,12 @@ def load_model(
         try:
             snapshot_download_kwargs = {}
             if cfg.base_model_ignore_patterns:
-                snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
-            cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
+                snapshot_download_kwargs[
+                    "ignore_patterns"
+                ] = cfg.base_model_ignore_patterns
+            cache_model_path = Path(
+                snapshot_download(base_model, **snapshot_download_kwargs)
+            )
             files = (
                 list(cache_model_path.glob("*.pt"))
                 + list(cache_model_path.glob("*.safetensors"))
@@ -116,8 +131,13 @@ def load_model(
             trust_remote_code=True if cfg.trust_remote_code is True else False,
         )
     else:
+        config = AutoConfig.from_pretrained(
+            base_model,
+            trust_remote_code=True if cfg.trust_remote_code is True else False,
+        )
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,

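The new `AutoConfig.from_pretrained(..., trust_remote_code=...)` path is what lets architectures with custom modeling code, such as MPT, be configured before their weights are instantiated; that is the hook the PR title refers to. A hedged sketch of the idea (the model id and the `attn_config` key layout are assumptions about the MPT repo, not shown in this diff):

```python
from transformers import AutoConfig, AutoModelForCausalLM

base_model = "mosaicml/mpt-7b"  # illustrative model id
config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
# MPT exposes its attention implementation on the config; "triton" is one
# of the options its remote code documents (assumed layout).
config.attn_config["attn_impl"] = "triton"
model = AutoModelForCausalLM.from_pretrained(
    base_model, config=config, trust_remote_code=True
)
```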
View File

@@ -26,7 +26,10 @@ class InterpolatingLogScheduler(LRScheduler):
         if self.last_epoch <= 0:
             lrs = [self.min_lr for base_lr in self.base_lrs]
         elif self.last_epoch < self.num_steps:
-            lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs]
+            lrs = [
+                self.min_lr * (self.q ** (self.last_epoch - 1))
+                for base_lr in self.base_lrs
+            ]
         else:
             lrs = [self.max_lr for base_lr in self.base_lrs]

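The reflowed comprehension implements a geometric (log-linear) interpolation: each step multiplies the rate by a fixed ratio `q`, walking from `min_lr` to `max_lr` in `num_steps` steps. A worked sketch of the arithmetic (the definition of `q` is inferred from the formula shown, since it lives outside this hunk):

```python
min_lr, max_lr, num_steps = 1e-5, 1e-3, 5
q = (max_lr / min_lr) ** (1 / (num_steps - 1))  # assumed definition of self.q
lrs = [min_lr * q ** (step - 1) for step in range(1, num_steps + 1)]
# -> [1e-05, ~3.16e-05, 1e-04, ~3.16e-04, 1e-03]: evenly spaced on a log scale
```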
View File

@@ -1,6 +1,7 @@
 from termcolor import colored
 import logging

+
 def check_dataset_labels(dataset, tokenizer):
     # the dataset is already shuffled, so let's just check the first 5 elements
     for idx in range(5):
@@ -11,7 +12,7 @@ def check_example_labels(example, tokenizer):
     # Get the input_ids, labels, and attention_mask from the dataset
     input_ids = example["input_ids"]
     labels = example["labels"]
-    attention_mask =example["attention_mask"]
+    attention_mask = example["attention_mask"]

     # You can compare the input_ids and labels element-wise
     # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -21,9 +22,7 @@ def check_example_labels(example, tokenizer):
     ):
         decoded_input_token = tokenizer.decode(input_id)
         # Choose the color based on whether the label has the ignore value or not
-        color = (
-            "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-        )
+        color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
         colored_token = colored(decoded_input_token, color) + colored(
             f"({label_id}, {mask}, {input_id})", "white"
         )

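`check_example_labels` prints every token colored by its label so masking mistakes are visible at a glance. A compact sketch of the same coloring rule using `termcolor` (the wrapper function is illustrative):

```python
from termcolor import colored

def colorize(token: str, label_id: int) -> str:
    """red = ignored (-100), yellow = label 0, green = trained on."""
    color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
    return colored(token, color)

print(colorize("Hello", -100), colorize("world", 42))  # red, then green
```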
View File

@@ -30,16 +30,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = (
-        cfg.save_steps
-        if cfg.save_steps is not None
-        else min(int(0.05 * total_num_steps), 200)
-    )
-    eval_steps = (
-        cfg.eval_steps
-        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
-        else save_steps
-    )
+    save_steps = cfg.save_steps
+    eval_steps = cfg.eval_steps

     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -86,26 +78,33 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
-        per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size,
+        per_device_eval_batch_size=cfg.eval_batch_size
+        if cfg.eval_batch_size is not None
+        else cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps",
+        save_strategy="steps" if save_steps else "epoch",
         eval_steps=eval_steps if cfg.val_set_size > 0 else None,
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=True
-        if cfg.val_set_size > 0 and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
+        if cfg.val_set_size > 0
+        and save_steps is not None
+        and save_steps % eval_steps == 0
+        and cfg.load_in_8bit is not True
         else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
-        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
+        lr_scheduler_type=cfg.lr_scheduler
+        if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
+        else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )
@@ -158,6 +157,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             cfg.learning_rate,
             total_steps=total_num_steps,
             epochs=cfg.num_epochs,
+            div_factor=10,
             **lr_scheduler_kwargs,
         )
     elif cfg.lr_scheduler == "log_sweep":
@@ -191,7 +191,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         data_collator_kwargs["pad_to_multiple_of"] = 8

     callbacks = []
-    if cfg.adapter == 'lora':
+    if cfg.adapter == "lora":
         callbacks.append(SavePeftModelCallback)

     trainer = transformers.Trainer(

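The simplification above moves responsibility for sane `save_steps`/`eval_steps` values into the config, and `load_best_model_at_end` now guards explicitly against an unset or incompatible pair. A sketch of that guard in isolation (the function wrapper is illustrative):

```python
def can_load_best_model_at_end(val_set_size, save_steps, eval_steps, load_in_8bit):
    """Mirror of the condition built into TrainingArguments above."""
    return (
        val_set_size > 0
        and save_steps is not None
        and save_steps % eval_steps == 0  # saves must land on eval boundaries
        and load_in_8bit is not True
    )

# can_load_best_model_at_end(0.05, 200, 50, False) -> True (200 == 4 * 50)
```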
View File

@@ -2,7 +2,9 @@ import os


 def setup_wandb_env_vars(cfg):
-    if cfg.wandb_project and len(cfg.wandb_project) > 0:
+    if cfg.wandb_mode and cfg.wandb_mode == "offline":
+        os.environ["WANDB_MODE"] = cfg.wandb_mode
+    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
         cfg.use_wandb = True
     if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
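The new branch lets a config force offline logging by exporting `WANDB_MODE` before anything initializes wandb, and it takes precedence over the project setting. A runnable sketch of that precedence (`SimpleNamespace` stands in for the real config object):

```python
import os
from types import SimpleNamespace

def setup_wandb_env_vars(cfg):
    if cfg.wandb_mode and cfg.wandb_mode == "offline":
        os.environ["WANDB_MODE"] = cfg.wandb_mode  # offline: log locally only
    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = cfg.wandb_project
        cfg.use_wandb = True

cfg = SimpleNamespace(wandb_mode="offline", wandb_project="axolotl", use_wandb=False)
setup_wandb_env_vars(cfg)  # offline wins; WANDB_PROJECT is never set
```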