Allow "weight: 0" in messages to mask them (#1703)
Allow in message objects the additional key `weight`, which can be set to 0 (or 1) to cause that message to be masked out (or left unmasked) for training (similar to [1]). This is helpful for training the model to be robust and capable of error recovery upon a bad assistant message. A missing `weight` key defaults to weight 1, to guarantee downward compatibility. [1]: https://github.com/mistralai/mistral-finetune
This commit is contained in:
@@ -143,6 +143,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||
),
|
||||
"value": t[value_key],
|
||||
"weight": 1
|
||||
if "weight" not in t or t["weight"] is None
|
||||
else t["weight"],
|
||||
}
|
||||
for t in conversations
|
||||
]
|
||||
|
||||
@@ -377,7 +377,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
role, content = part
|
||||
if len(part) <= 2:
|
||||
role, content = part
|
||||
weight = 1
|
||||
else:
|
||||
role, content, weight = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||
@@ -403,7 +407,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if self.train_on_inputs:
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
@@ -439,13 +443,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
if weight == 0:
|
||||
# everything from this is masked out from the labels
|
||||
# (role is masked out too because it makes no sense if contents is masked out)
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
elif empty_role:
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
if self.train_on_inputs:
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
|
||||
@@ -319,6 +319,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
|
||||
conv = self._conversation.copy()
|
||||
|
||||
original_source = source.copy()
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
@@ -360,8 +361,27 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
return conv.get_turns()
|
||||
turns = list(conv.get_turns())
|
||||
original_source_length = len(original_source)
|
||||
assert len(turns) in [
|
||||
original_source_length - 1,
|
||||
original_source_length,
|
||||
original_source_length + 1,
|
||||
]
|
||||
if len(turns) == original_source_length + 1:
|
||||
original_source = [{"weight": None}] + original_source
|
||||
elif len(turns) == original_source_length - 1:
|
||||
original_source = original_source[1:]
|
||||
return [
|
||||
(*turn, weight)
|
||||
for turn, weight in zip(
|
||||
turns,
|
||||
[
|
||||
1 if "weight" not in e or e["weight"] is None else e["weight"]
|
||||
for e in original_source
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
@@ -52,6 +52,51 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset_with_weights")
|
||||
def fixture_sharegpt_dataset_with_weights():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "repeat",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "hello",
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "hello",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "rehello",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "rehello",
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "goodbye",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "goodbye",
|
||||
"weight": 0,
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="glaive_dataset")
|
||||
def fixture_sharegpt_glaive_dataset():
|
||||
return Dataset.from_list(
|
||||
@@ -162,6 +207,46 @@ class TestSharegptLlama3:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_tokenization_with_weights(
|
||||
self, sharegpt_dataset_with_weights, llama3_tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="llama3",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 9125, 128007, # system header
|
||||
271, 31724, 128009, # sys prompt, eot
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 11310, 4896, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 11310, 4896, 128009,
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class TestSharegptChatML:
|
||||
"""
|
||||
@@ -197,7 +282,40 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_no_double_im_end_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
# 28705, 13, is " \n"
|
||||
1, # bos
|
||||
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
@@ -225,7 +343,39 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_no_train_on_input_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, # system
|
||||
-100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, # gpt with weight zero
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight zero
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
@@ -253,6 +403,38 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
True, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, # gpt with weight 0
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human with weight 0
|
||||
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight 0
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
|
||||
Reference in New Issue
Block a user