Allow "weight: 0" in messages to mask them (#1703)
Allow in message objects the additional key `weight`, which can be set to 0 (or 1) to cause that message to be masked out (or left unmasked) for training (similar to [1]). This is helpful for training the model to be robust and capable of error recovery upon a bad assistant message. A missing `weight` key defaults to weight 1, to guarantee downward compatibility. [1]: https://github.com/mistralai/mistral-finetune
This commit is contained in:
@@ -52,6 +52,51 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset_with_weights")
|
||||
def fixture_sharegpt_dataset_with_weights():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "repeat",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "hello",
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "hello",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "rehello",
|
||||
"weight": 0,
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "rehello",
|
||||
"weight": 1,
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "goodbye",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "goodbye",
|
||||
"weight": 0,
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="glaive_dataset")
|
||||
def fixture_sharegpt_glaive_dataset():
|
||||
return Dataset.from_list(
|
||||
@@ -162,6 +207,46 @@ class TestSharegptLlama3:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_tokenization_with_weights(
|
||||
self, sharegpt_dataset_with_weights, llama3_tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="llama3",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
llama3_tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
128000, # bos
|
||||
128006, 9125, 128007, # system header
|
||||
271, 31724, 128009, # sys prompt, eot
|
||||
128006, 882, 128007, # user header
|
||||
271, 15339, 128009, # user prompt eot
|
||||
128006, 78191, 128007, # assistant header
|
||||
271, 15339, 128009, # assistant response eot
|
||||
128006, 882, 128007,
|
||||
271, 11310, 4896, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 11310, 4896, 128009,
|
||||
128006, 882, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
128006, 78191, 128007,
|
||||
271, 19045, 29474, 128009,
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class TestSharegptChatML:
|
||||
"""
|
||||
@@ -197,7 +282,40 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_no_double_im_end_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
# 28705, 13, is " \n"
|
||||
1, # bos
|
||||
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 267, 21558, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
@@ -225,7 +343,39 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
def test_no_train_on_input_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, # system
|
||||
-100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, # gpt with weight zero
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight zero
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
@@ -253,6 +403,38 @@ class TestSharegptChatML:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_w_train_on_input_with_weights(
|
||||
self, sharegpt_dataset_with_weights, tokenizer
|
||||
):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
True, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, sharegpt_dataset_with_weights, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 25997, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 32000, 28705, 13, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, # gpt with weight 0
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, # human with weight 0
|
||||
32001, 13892, 13, 267, 21558, 32000, 28705, 13, # gpt
|
||||
32001, 2188, 13, 12684, 17664, 32000, 28705, 13, # human
|
||||
-100, -100, -100, -100, -100, -100, -100, -100 # gpt with weight 0
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
|
||||
Reference in New Issue
Block a user