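Various bugfixes: skip flash-attn patching on mps/cpu devices, set the pad token for LlamaTokenizerFast, load a previously saved dataset from disk when `datasets` is a path string, drop examples longer than the sequence length, and add deepspeed/einops to the requirements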

This commit is contained in:
Wing Lian
2023-04-14 21:37:07 -04:00
parent 45f77dd51e
commit 80b2ed29d8
5 changed files with 33 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
base_model: decapoda-research/llama-65b-hf base_model: huggyllama/llama-7b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: true load_in_8bit: true
@@ -33,8 +33,8 @@ num_epochs: 5
learning_rate: 0.00003 learning_rate: 0.00003
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: True bf16: true
tf32: True tf32: true
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
deepspeed: deepspeed:

View File

@@ -10,3 +10,6 @@ accelerate
sentencepiece sentencepiece
wandb wandb
flash-attn flash-attn
deepspeed
einops

View File

@@ -11,7 +11,7 @@ import torch
import transformers import transformers
import yaml import yaml
from attrdict import AttrDefault from attrdict import AttrDefault
from datasets import load_dataset, IterableDataset, Dataset from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
from peft import ( from peft import (
LoraConfig, LoraConfig,
get_peft_model, get_peft_model,
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
if adapter != "lora": if adapter != "lora":
raise NotImplementedError(f"{adapter} peft adapter not available") raise NotImplementedError(f"{adapter} peft adapter not available")
if "llama" in base_model: if "llama" in base_model:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn if cfg.device not in ["mps", "cpu"]:
replace_llama_attn_with_flash_attn() from axolotl.flash_attn import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
try: try:
if "llama" in base_model: if "llama" in base_model:
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
except: except:
tokenizer = AutoTokenizer.from_pretrained(base_model) tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.__class__.__name__ == "LlamaTokenizer": if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -255,8 +256,9 @@ def train(
return return
datasets = [] datasets = []
if len(cfg.datasets) == 1 and cfg.datasets[0].type == "arrow": if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
dataset = load_dataset(cfg.datasets[0].path, split="train") # assumption that we are loading a previously saved/cached dataset
dataset = load_from_disk(cfg.datasets)
else: else:
for d in cfg.datasets: for d in cfg.datasets:
ds: IterableDataset = load_dataset( ds: IterableDataset = load_dataset(
@@ -288,7 +290,6 @@ def train(
[_ for _ in constant_len_dataset] [_ for _ in constant_len_dataset]
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42) ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
dataset.save_to_disk("data/last_run") dataset.save_to_disk("data/last_run")
print(dataset)
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]

View File

@@ -23,6 +23,7 @@ install_requires =
sentencepiece sentencepiece
wandb wandb
flash-attn flash-attn
einops
[options.packages.find] [options.packages.find]
where = src where = src

View File

@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
buffer_len = 0 buffer_len = 0
if example: if example:
input_ids = example["input_ids"] # just going to drop data points that are too long
attention_mask = example["attention_mask"] if len(example["input_ids"]) <= self.seq_length:
labels = example["labels"] input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token: if add_concat_token:
input_ids.append(self.concat_token_id) input_ids.append(self.concat_token_id)
attention_mask.append(1) attention_mask.append(1)
labels.append(self.concat_token_id) labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long) input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
attention_mask_with_concat = torch.tensor( attention_mask_with_concat = torch.tensor(
attention_mask, dtype=torch.long attention_mask, dtype=torch.long
) )
labels_with_concat = torch.tensor(labels, dtype=torch.long) labels_with_concat = torch.tensor(labels, dtype=torch.long)
buffer["input_ids"].append(input_ids_with_concat) buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat) buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat) buffer["labels"].append(labels_with_concat)
buffer_len += len(input_ids) buffer_len += len(input_ids)