various bugfixes
This commit is contained in:
@@ -11,7 +11,7 @@ import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from attrdict import AttrDefault
|
||||
from datasets import load_dataset, IterableDataset, Dataset
|
||||
from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
get_peft_model,
|
||||
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
||||
if adapter != "lora":
|
||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||
if "llama" in base_model:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
replace_llama_attn_with_flash_attn()
|
||||
if cfg.device not in ["mps", "cpu"]:
|
||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||
replace_llama_attn_with_flash_attn()
|
||||
|
||||
try:
|
||||
if "llama" in base_model:
|
||||
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
|
||||
except:
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
||||
|
||||
if tokenizer.__class__.__name__ == "LlamaTokenizer":
|
||||
if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
@@ -255,8 +256,9 @@ def train(
|
||||
return
|
||||
|
||||
datasets = []
|
||||
if len(cfg.datasets) == 1 and cfg.datasets[0].type == "arrow":
|
||||
dataset = load_dataset(cfg.datasets[0].path, split="train")
|
||||
if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
|
||||
# assumption that we are loading a previously saved/cached dataset
|
||||
dataset = load_from_disk(cfg.datasets)
|
||||
else:
|
||||
for d in cfg.datasets:
|
||||
ds: IterableDataset = load_dataset(
|
||||
@@ -288,7 +290,6 @@ def train(
|
||||
[_ for _ in constant_len_dataset]
|
||||
).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
|
||||
dataset.save_to_disk("data/last_run")
|
||||
print(dataset)
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
|
||||
Reference in New Issue
Block a user