various bugfixes
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
base_model: decapoda-research/llama-65b-hf
|
base_model: huggyllama/llama-7b
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
@@ -33,8 +33,8 @@ num_epochs: 5
|
|||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
bf16: True
|
bf16: true
|
||||||
tf32: True
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -10,3 +10,6 @@ accelerate
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
flash-attn
|
flash-attn
|
||||||
|
deepspeed
|
||||||
|
einops
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
import yaml
|
import yaml
|
||||||
from attrdict import AttrDefault
|
from attrdict import AttrDefault
|
||||||
from datasets import load_dataset, IterableDataset, Dataset
|
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
|
||||||
from peft import (
|
from peft import (
|
||||||
LoraConfig,
|
LoraConfig,
|
||||||
get_peft_model,
|
get_peft_model,
|
||||||
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
if adapter != "lora":
|
if adapter != "lora":
|
||||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||||
if "llama" in base_model:
|
if "llama" in base_model:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
if cfg.device not in ["mps", "cpu"]:
|
||||||
replace_llama_attn_with_flash_attn()
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "llama" in base_model:
|
if "llama" in base_model:
|
||||||
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
|||||||
except:
|
except:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "LlamaTokenizer":
|
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||||
@@ -255,8 +256,9 @@ def train(
|
|||||||
return
|
return
|
||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
if len(cfg.datasets) == 1 and cfg.datasets[0].type == "arrow":
|
if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
|
||||||
dataset = load_dataset(cfg.datasets[0].path, split="train")
|
# assumption that we are loading a previously saved/cached dataset
|
||||||
|
dataset = load_from_disk(cfg.datasets)
|
||||||
else:
|
else:
|
||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
ds: IterableDataset = load_dataset(
|
ds: IterableDataset = load_dataset(
|
||||||
@@ -288,7 +290,6 @@ def train(
|
|||||||
[_ for _ in constant_len_dataset]
|
[_ for _ in constant_len_dataset]
|
||||||
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
||||||
dataset.save_to_disk("data/last_run")
|
dataset.save_to_disk("data/last_run")
|
||||||
print(dataset)
|
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ install_requires =
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
flash-attn
|
flash-attn
|
||||||
|
einops
|
||||||
|
|
||||||
[options.packages.find]
|
[options.packages.find]
|
||||||
where = src
|
where = src
|
||||||
|
|||||||
@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
|
|
||||||
if example:
|
if example:
|
||||||
input_ids = example["input_ids"]
|
# just going to drop data points that are too long
|
||||||
attention_mask = example["attention_mask"]
|
if len(example["input_ids"]) <= self.seq_length:
|
||||||
labels = example["labels"]
|
input_ids = example["input_ids"]
|
||||||
|
attention_mask = example["attention_mask"]
|
||||||
|
labels = example["labels"]
|
||||||
|
|
||||||
if add_concat_token:
|
if add_concat_token:
|
||||||
input_ids.append(self.concat_token_id)
|
input_ids.append(self.concat_token_id)
|
||||||
attention_mask.append(1)
|
attention_mask.append(1)
|
||||||
labels.append(self.concat_token_id)
|
labels.append(self.concat_token_id)
|
||||||
|
|
||||||
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
|
||||||
attention_mask_with_concat = torch.tensor(
|
attention_mask_with_concat = torch.tensor(
|
||||||
attention_mask, dtype=torch.long
|
attention_mask, dtype=torch.long
|
||||||
)
|
)
|
||||||
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
labels_with_concat = torch.tensor(labels, dtype=torch.long)
|
||||||
|
|
||||||
buffer["input_ids"].append(input_ids_with_concat)
|
buffer["input_ids"].append(input_ids_with_concat)
|
||||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||||
buffer["labels"].append(labels_with_concat)
|
buffer["labels"].append(labels_with_concat)
|
||||||
buffer_len += len(input_ids)
|
buffer_len += len(input_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user