Tokenization open assistant (#1)

* refactor prompt tokenization to more easily support open assistant

* add open assistant handling, more logging, black formatting
Wing Lian
2023-04-18 01:45:49 -04:00
committed by GitHub
parent eb808903e5
commit 87d7825435
2 changed files with 149 additions and 51 deletions

View File

@@ -37,6 +37,7 @@ from axolotl.prompt_tokenizers import (
     ShareGPTPromptTokenizingStrategy,
     LLAMA_DEFAULT_PAD_TOKEN,
     GPTeacherPromptTokenizingStrategy,
+    OpenAssistantPromptTokenizingStrategy,
 )
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -56,7 +57,15 @@ def setup_wandb_env_vars(cfg):
         os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
 
 
-def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, adapter="lora", inference: bool=False):
+def load_model(
+    base_model,
+    base_model_config,
+    model_type,
+    tokenizer_type,
+    cfg,
+    adapter="lora",
+    inference: bool = False,
+):
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
@@ -67,13 +76,17 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+
+            logging.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 
-    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,
+    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
     try:
         if cfg.load_4bit:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_int4_lora_model
+            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
+                replace_peft_model_with_int4_lora_model,
+            )
             replace_peft_model_with_int4_lora_model()
 
         from peft import (
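Note that the trailing comma on the `torch_dtype` assignment above makes the value a one-element tuple rather than a dtype; black's rewrite only makes that tuple explicit. A minimal sketch of the scalar assignment the later `torch_dtype=torch_dtype` usage presumably expects (the `cfg` attributes stand in for the config object used throughout this file):

    import torch

    # presumably intended: a plain dtype, not a (dtype,) tuple
    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32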
@@ -92,18 +105,26 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
             from huggingface_hub import snapshot_download
 
             cache_model_path = Path(snapshot_download(base_model))
-            files = list(cache_model_path.glob('*.pt')) + list(cache_model_path.glob('*.safetensors')) + list(cache_model_path.glob('*.bin'))
+            files = (
+                list(cache_model_path.glob("*.pt"))
+                + list(cache_model_path.glob("*.safetensors"))
+                + list(cache_model_path.glob("*.bin"))
+            )
             if len(files) > 0:
                 model_path = str(files[0])
             else:
-                logging.warning("unable to find a cached model file, this will likely fail...")
+                logging.warning(
+                    "unable to find a cached model file, this will likely fail..."
+                )
                 model_path = str(cache_model_path)
             model, tokenizer = load_llama_model_4bit_low_ram(
                 base_model_config if base_model_config else base_model,
                 model_path,
                 device_map=cfg.device_map,
                 groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1 if cfg.gptq_model_v1 is not None else True,
+                is_v1_model=cfg.gptq_model_v1
+                if cfg.gptq_model_v1 is not None
+                else True,
             )
             load_in_8bit = False
         elif is_llama_derived_model:
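For illustration, a self-contained sketch of the cache lookup in the hunk above, with a hypothetical model id, showing what `snapshot_download` leaves on disk and how the first weight file is picked:

    from pathlib import Path
    from huggingface_hub import snapshot_download

    cache_model_path = Path(snapshot_download("someuser/some-llama-7b"))  # hypothetical id
    files = (
        list(cache_model_path.glob("*.pt"))
        + list(cache_model_path.glob("*.safetensors"))
        + list(cache_model_path.glob("*.bin"))
    )
    # fall back to the snapshot directory if no single weight file is found
    model_path = str(files[0]) if files else str(cache_model_path)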
@@ -120,7 +141,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
-    except:
+    except Exception as e:
+        logging.error(
+            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
+        )
+        logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
@@ -145,7 +170,6 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
 
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -165,7 +189,12 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
         )
         if cfg.lora_model_dir:
-            model = PeftModel.from_pretrained(model, cfg.lora_model_dir, device_map = cfg.device_map, torch_dtype=torch.float16)
+            model = PeftModel.from_pretrained(
+                model,
+                cfg.lora_model_dir,
+                device_map=cfg.device_map,
+                torch_dtype=torch.float16,
+            )
         else:
             model = get_peft_model(model, lora_config)
@@ -174,9 +203,11 @@ def load_model(base_model, base_model_config, model_type, tokenizer_type, cfg, a
     if cfg.load_4bit:
         # Scales to half
-        logging.info('Fitting 4bit scales and zeros to half')
+        logging.info("Fitting 4bit scales and zeros to half")
         for n, m in model.named_modules():
-            if 'Autograd4bitQuantLinear' in str(type(m)) or 'Linear4bitLt' in str(type(m)):
+            if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str(
+                type(m)
+            ):
                 if hasattr(m, "is_v1_model") and m.is_v1_model:
                     m.zeros = m.zeros.half()
                 m.scales = m.scales.half()
@@ -236,37 +267,44 @@ def check_dataset_labels(dataset, tokenizer):
 def do_inference(cfg, model, tokenizer):
-    tokenizer.add_special_tokens({'unk_token': '<unk>'})
-    tokenizer.add_special_tokens({'bos_token': '<s>'})
-    tokenizer.add_special_tokens({'eos_token': '</s>'})
+    tokenizer.add_special_tokens({"unk_token": "<unk>"})
+    tokenizer.add_special_tokens({"bos_token": "<s>"})
+    tokenizer.add_special_tokens({"eos_token": "</s>"})
 
     instruction = "Tell me a joke about dromedaries."
    input = ""
-    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(instruction=instruction, input=input)
+    prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format(
+        instruction=instruction, input=input
+    )
     batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
     model.eval()
     with torch.no_grad():
         # gc = GenerationConfig() # TODO swap out and use this
-        generated = model.generate(inputs=batch["input_ids"].to("cuda"),
-                                   do_sample=True, use_cache=True,
-                                   repetition_penalty=1.1,
-                                   max_new_tokens=100,
-                                   temperature=0.9,
-                                   top_p=0.95,
-                                   top_k=40,
-                                   return_dict_in_generate=True,
-                                   output_attentions=False,
-                                   output_hidden_states=False,
-                                   output_scores=False)
-    print(tokenizer.decode(generated['sequences'].cpu().tolist()[0]))
+        generated = model.generate(
+            inputs=batch["input_ids"].to("cuda"),
+            do_sample=True,
+            use_cache=True,
+            repetition_penalty=1.1,
+            max_new_tokens=100,
+            temperature=0.9,
+            top_p=0.95,
+            top_k=40,
+            return_dict_in_generate=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            output_scores=False,
+        )
+    print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
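As a sketch of the `# gc = GenerationConfig() # TODO swap out and use this` comment above, the same sampling settings can be carried by a `transformers.GenerationConfig` and passed to `model.generate`; this is not the commit's code, just the shape the TODO points at, with parameter values mirroring the call above:

    from transformers import GenerationConfig

    gc = GenerationConfig(
        do_sample=True,
        use_cache=True,
        repetition_penalty=1.1,
        max_new_tokens=100,
        temperature=0.9,
        top_p=0.95,
        top_k=40,
        return_dict_in_generate=True,
    )
    generated = model.generate(
        inputs=batch["input_ids"].to("cuda"), generation_config=gc
    )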
 def choose_config(path: Path):
     yaml_files = [file for file in path.glob("*.yml")]
 
     if not yaml_files:
-        raise ValueError("No YAML config files found in the specified directory. Are you using a .yml extension?")
+        raise ValueError(
+            "No YAML config files found in the specified directory. Are you using a .yml extension?"
+        )
 
     print("Choose a YAML file:")
     for idx, file in enumerate(yaml_files):
@@ -376,6 +414,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     return trainer
 
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -420,7 +459,13 @@ def train(
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and lora_config...")
     model, tokenizer, lora_config = load_model(
-        cfg.base_model, cfg.base_model_config, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter, inference=("inference" in kwargs)
+        cfg.base_model,
+        cfg.base_model_config,
+        cfg.model_type,
+        cfg.tokenizer_type,
+        cfg,
+        adapter=cfg.adapter,
+        inference=("inference" in kwargs),
     )
 
     if "inference" in kwargs:
@@ -428,10 +473,26 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
-    max_packed_sequence_len = cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
-    ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
-    prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+    ds_hash = str(
+        md5(
+            (
+                str(max_packed_sequence_len)
+                + "@"
+                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+            ).encode("utf-8")
+        ).hexdigest()
+    )
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
 
     if any(prepared_ds_path.glob("*")):
         logging.info("Loading prepared dataset from disk...")
@@ -464,9 +525,18 @@ def train(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        elif d.type == "oasst":
+            ds_strategy = OpenAssistantPromptTokenizingStrategy(
+                AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+            )
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+            datasets.append(ds_wrapper)
         elif d.type == "gpteacher":
             ds_strategy = GPTeacherPromptTokenizingStrategy(
-                GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+                GPTeacherPrompter(),
+                tokenizer,
+                cfg.train_on_inputs,
+                cfg.sequence_len,
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
@@ -476,13 +546,17 @@ def train(
             )
             ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
+        else:
+            logging.error(f"unhandled prompt tokenization strategy: {d.type}")
 
     constant_len_dataset = ConstantLengthDataset(
-        tokenizer, datasets, seq_length=max_packed_sequence_len,
+        tokenizer,
+        datasets,
+        seq_length=max_packed_sequence_len,
     )
     logging.info("merging, packing, shuffling, and splitting master dataset")
-    dataset = Dataset.from_list(
-        [_ for _ in constant_len_dataset]
-    ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
+    dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+        test_size=cfg.val_set_size, shuffle=True, seed=42
+    )
 
     if cfg.local_rank == 0:
         logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
@@ -525,7 +599,9 @@ def train(
     if cfg.local_rank == 0:
         # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+        logging.info(
+            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+        )
         model.save_pretrained(cfg.output_dir)

View File

@@ -31,14 +31,18 @@ class PromptTokenizingStrategy(abc.ABC):
         pass
 
 
-class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
+class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        raise NotImplementedError
+
     def tokenize_prompt(self, prompt):
-        full_prompt = self._tokenize_full_prompt(prompt)
+        instruction, input, response = self.parse_instruction_fields(prompt)
+        full_prompt = self._build_full_prompt(instruction, input, response)
         tokenized_full_prompt = self._tokenize(full_prompt)
         if not self.train_on_inputs:
             user_prompt = self.prompter.build_prompt(
-                prompt["instruction"],
-                prompt["input"] if "input" in prompt else "",
+                instruction,
+                input,
             )
             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
@@ -49,11 +53,11 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return tokenized_full_prompt
 
-    def _tokenize_full_prompt(self, prompt):
+    def _build_full_prompt(self, instruction, input, response):
         return self.prompter.build_prompt(
-            prompt["instruction"],
-            prompt["input"] if "input" in prompt else "",
-            prompt["output"],
+            instruction,
+            input,
+            response,
         )
 
     def _tokenize(self, prompt, add_eos_token=True):
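The refactor above is a template method: subclasses supply `parse_instruction_fields`, while the base class builds the full prompt, tokenizes it, and, when `train_on_inputs` is false, uses `user_prompt_len` to mask the user portion of the labels. A minimal sketch of that masking step, assuming the usual -100 ignore index and list-typed labels (the values below are hypothetical; in the class above they come from the tokenizer):

    user_prompt_len = 4  # token count of the user prompt, as computed above
    labels = [101, 102, 103, 104, 105, 106, 107]
    # loss is then computed only on the response tokens
    labels = [-100] * user_prompt_len + labels[user_prompt_len:]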
@@ -76,11 +80,29 @@ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
         return result
 
 
-class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
-    def _tokenize_full_prompt(self, prompt):
-        return self.prompter.build_prompt(
-            prompt["instruction"],
-            prompt["input"],
-            prompt["response"],
-        )
+class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+        )
+
+
+class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["INSTRUCTION"],
+            "",
+            prompt["RESPONSE"],
+        )
+
+
+class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
+    def parse_instruction_fields(self, prompt) -> (str, str, str):
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["response"],
+        )