fix sharegpt handling from hf, don't worry about loading llama if using earlier transformers release

This commit is contained in:
Wing Lian
2023-04-20 09:19:46 -04:00
parent 8e2a5609b3
commit 8d437853c8
4 changed files with 29 additions and 7 deletions

View File

@@ -5,7 +5,8 @@ load_in_8bit: true
datasets: datasets:
- path: data/alpaca_data_gpt4.jsonl - path: data/alpaca_data_gpt4.jsonl
type: alpaca type: alpaca
- path: data/vicuna_cleaned.jsonl - path: anon8231489123/ShareGPT_Vicuna_unfiltered
data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
type: sharegpt type: sharegpt
- path: data/gpt4-instruct-similarity-0.6-dataset.jsonl - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
type: gpteacher type: gpteacher
@@ -30,6 +31,8 @@ wandb_log_model: checkpoint
output_dir: ./lora-llama-alpaca output_dir: ./lora-llama-alpaca
batch_size: 128 batch_size: 128
micro_batch_size: 16 micro_batch_size: 16
warmup_steps: 1000
save_steps:
num_epochs: 5 num_epochs: 5
learning_rate: 0.00003 learning_rate: 0.00003
train_on_inputs: false train_on_inputs: false

View File

@@ -128,6 +128,10 @@ conv_vicuna_v1_1 = Conversation(
class ShareGPTPrompter: class ShareGPTPrompter:
def build_prompt(self, source, tokenizer): def build_prompt(self, source, tokenizer):
# ignore the system prompt if provided
if source[0]["from"] == "system":
source.pop(0)
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations

View File

@@ -3,6 +3,7 @@ from hashlib import md5
from pathlib import Path from pathlib import Path
from datasets import load_from_disk, load_dataset, IterableDataset, Dataset from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
from huggingface_hub import hf_hub_download
from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
@@ -50,6 +51,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
logging.info("Loading raw datasets...") logging.info("Loading raw datasets...")
datasets = [] datasets = []
for d in cfg.datasets: for d in cfg.datasets:
ds = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset(d.path, streaming=True) load_dataset(d.path, streaming=True)
@@ -63,9 +65,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
"json", data_files=d.path, streaming=True, split=None "json", data_files=d.path, streaming=True, split=None
) )
elif ds_from_hub: elif ds_from_hub:
ds = load_dataset(d.path, streaming=True) if d.data_files:
ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
else:
ds = load_dataset(d.path, streaming=True)
else: else:
raise Exception(f"unhandled dataset load for {d.path}") fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
ds = load_dataset("json", data_files=fp, streaming=True, split=None)
if not ds:
raise Exception("unhandled dataset load")
if d.type == "alpaca": if d.type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy( ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -111,6 +119,8 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
seq_length=max_packed_sequence_len, seq_length=max_packed_sequence_len,
) )
logging.info("merging, packing, shuffling, and splitting master dataset") logging.info("merging, packing, shuffling, and splitting master dataset")
# TODO don't split dataset here, shuffle and save first, then split, that way we can
# re-split when loading again
dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split( dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
test_size=cfg.val_set_size, shuffle=True, seed=42 test_size=cfg.val_set_size, shuffle=True, seed=42
) )

View File

@@ -7,11 +7,16 @@ import torch
import transformers import transformers
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
AutoTokenizer, AutoTokenizer,
PreTrainedModel, PreTrainedModel,
) )
try:
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
)
except:
logging.warning("This version of transformers does not support Llama. Consider upgrading.")
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -95,7 +100,7 @@ def load_model(
else True, else True,
) )
load_in_8bit = False load_in_8bit = False
elif is_llama_derived_model: elif is_llama_derived_model and "LlamaForCausalLM" in globals():
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit, load_in_8bit=cfg.load_in_8bit,
@@ -130,7 +135,7 @@ def load_model(
if not tokenizer: if not tokenizer:
try: try:
if is_llama_derived_model: if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(model) tokenizer = LlamaTokenizer.from_pretrained(model)
else: else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model) tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)