experimental llama 2 chat support (#296)
* experimental llama 2 chat support * few small fixes * llama2_chat * small fix to follow original implementation * small fixes and added fixtures/tests * fix -mixed up inference and finetuning conversations * args - small fix * small fix * small adjustment and warning * fix with pre-commit --------- Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bb53a165f5
commit
3392270544
205
src/axolotl/prompt_strategies/llama2_chat.py
Normal file
205
src/axolotl/prompt_strategies/llama2_chat.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Prompt Strategy for finetuning Llama2 chat models
|
||||
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.
|
||||
|
||||
This implementation is based on the Vicuna PR and the fastchat repo, see also:
|
||||
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847
|
||||
|
||||
Use dataset type: "llama2_chat" in config.yml to use this prompt style.
|
||||
|
||||
E.g. in the config.yml:
|
||||
```
|
||||
datasets:
|
||||
- path: llama_finetune_train.jsonl
|
||||
type: llama2_chat
|
||||
```
|
||||
|
||||
The dataset itself should look like this:
|
||||
```
|
||||
{"conversations": [{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"}, ...]}
|
||||
```
|
||||
in a jsonl file. The first message should be from the human, the second from gpt.
|
||||
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).
|
||||
|
||||
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
|
||||
|
||||
@dataclass
class Llama2ChatConversation:
    """Stores the turns of one Llama 2 chat and renders them as a prompt string.

    Modelled on https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
    """

    name: str = "llama2"
    # System preamble; it already opens the first "[INST]" block, which is why
    # get_prompt() does not emit a role marker for the first message.
    system: str = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )
    # (user marker, assistant marker)
    roles: Sequence[str] = ("[INST]", "[/INST]")
    # Each entry is a [role, text] pair, in conversation order.
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0
    # Un-annotated on purpose: these stay plain class attributes, not dataclass fields.
    sep = " "
    sep2 = " </s><s>"
    stop_token_ids = [2]

    def get_prompt(self) -> str:
        """Render the stored messages into a single prompt string.

        The first message is glued directly onto the system preamble; every
        later message is written as ``role + " " + text`` followed by ``sep``
        (even positions) or ``sep2`` (odd positions). A trailing message that
        carries the user role is left out entirely.
        """
        final_index = len(self.messages) - 1
        pieces = []
        for index, (role, message) in enumerate(self.messages):
            if index == final_index and role == self.roles[0]:
                # last message is from user (due to length),
                # return prompt without it for training
                break
            text = message.strip()
            if index == 0:
                pieces.append(self.system + text)
                continue
            separator = self.sep2 if index % 2 else self.sep
            pieces.append(role + " " + text + separator)
        return "".join(pieces)

    def append_message(self, role: str, message: str):
        """Add one ``[role, message]`` turn to the end of the conversation."""
        turn = [role, message]
        self.messages.append(turn)
|
||||
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for Llama 2 chat conversations stored as ShareGPT-style records.
    adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base strategy, then prepare the tokenizer."""
        super().__init__(*args, **kwargs)
        # Every sample is padded to exactly self.sequence_len tokens (see
        # tokenize_prompt), so this value must follow the configured length
        # instead of being forced to 4096. 4096 is kept only as a fallback.
        # NOTE(review): presumably PromptTokenizingStrategy.__init__ stores the
        # sequence_len argument on self - confirm against the base class.
        if getattr(self, "sequence_len", None) is None:
            self.sequence_len = 4096
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json

    def tokenize_prompt(self, prompt):
        """Tokenize one conversation record and mask the non-assistant tokens.

        Args:
            prompt: dataset record passed on to ``self.prompter.build_prompt``.

        Returns:
            dict with ``input_ids``, ``labels`` and ``attention_mask`` as plain
            lists. ``labels`` holds ``IGNORE_TOKEN_ID`` wherever no loss should
            be computed; the whole sample is masked when the turn-by-turn
            length bookkeeping disagrees with the tokenized length.
        """
        conv = next(self.prompter.build_prompt(prompt))
        conversation_str = conv.get_prompt()

        # Tokenize conversations; padded and truncated to self.sequence_len.
        input_ids = self.tokenizer(
            conversation_str,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sequence_len,
            truncation=True,
        ).input_ids[0]
        target = input_ids.clone()

        # Mask targets. Only compute loss on the assistant outputs.
        sep = conv.roles[1]

        # Number of non-padding tokens in the encoded conversation.
        total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

        turns = conversation_str.split(conv.sep2)
        # Start after the first token (presumably the BOS token - confirm).
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            turn_len = len(self.tokenizer(turn).input_ids)

            # A well-formed turn splits into exactly [instruction, answer].
            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1

            # Ignore the user instructions
            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len + 2  # due to length of role token

        # Everything past the last processed turn (padding included) is masked.
        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < self.sequence_len and cur_len != total_len:
            # Bookkeeping and tokenizer disagree: drop the sample from the loss.
            target[:] = IGNORE_TOKEN_ID
            logging.warning(
                "WARNING: tokenization mismatch: %s vs. %s. (ignored)",
                cur_len,
                total_len,
            )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
        input_ids = input_ids.tolist()
        target = target.tolist()
        # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
        # follows the original llama implementation
        for i in range(2, total_len - 2):
            if input_ids[i] == 29961:
                input_ids[i] = 518
            if target[i] == 29961:
                target[i] = 518
        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }
|
||||
|
||||
|
||||
class Llama2ChatPrompter:  # pylint: disable=too-few-public-methods
    """
    A prompter that generates prompts for Llama2 models.
    """

    # Default system preamble, used when the record has no "system" entry.
    system_prompt = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )

    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
        """Turn one dataset record into a ``Llama2ChatConversation``.

        Args:
            source: mapping with a ``"conversations"`` list of
                ``{"from": ..., "value": ...}`` entries. ``"from"`` is
                ``"human"`` or ``"gpt"``; the first entry may be ``"system"``.

        Yields:
            A single conversation holding the non-empty turns.

        Raises:
            IndexError: the record has no turns, or fewer than two turns once
                an optional system entry has been removed.
            KeyError: a turn uses a role other than ``"human"`` / ``"gpt"``.
            AssertionError: the turns do not alternate human, gpt, human, ...
        """
        # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
        source = source["conversations"]  # fix data structure for datasets

        if not source:
            raise IndexError("conversation has no turns")

        # if system prompt provided, use it
        if source[0]["from"] == "system":
            system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
            source = source[1:]
        else:
            system = self.system_prompt

        conv = Llama2ChatConversation(system=system)

        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError(
                "conversation needs at least one human turn and one gpt turn"
            )

        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []  # pylint: disable=R0801
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Raised explicitly (rather than via ``assert``) so the alternation
            # check is not stripped when Python runs with -O.
            if role != conv.roles[j % 2]:
                raise AssertionError(
                    f"turn {j} is from {sentence['from']!r}, breaking human/gpt alternation"
                )
            # Turns with an empty value are dropped.
            if sentence["value"]:
                conv.append_message(role, sentence["value"])
        yield conv
|
||||
|
||||
|
||||
def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
    """Create the Llama 2 chat tokenizing strategy for ``tokenizer``.

    A fresh ``Llama2ChatPrompter`` is paired with the tokenizer, and
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` are forwarded after them,
    all positionally.
    """
    prompter = Llama2ChatPrompter()
    strategy_args = (prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len)
    return LLama2ChatTokenizingStrategy(*strategy_args)
|
||||
@@ -378,7 +378,7 @@ def load_prepare_datasets(
|
||||
[
|
||||
d
|
||||
for d in dataset
|
||||
if len(d["input_ids"]) < cfg.sequence_len
|
||||
if len(d["input_ids"]) <= cfg.sequence_len
|
||||
and len(d["input_ids"]) > 0
|
||||
and len(d["input_ids"]) == len(d["attention_mask"])
|
||||
and len(d["input_ids"]) == len(d["labels"])
|
||||
|
||||
Reference in New Issue
Block a user