Allow "weight: 0" in messages to mask them (#1703)

Allow message objects to carry an additional `weight` key, which can be set
to 0 (or 1) to cause that message to be masked out (or left unmasked)
during training (similar to [1]). This is helpful for training the model to be
robust and capable of recovering from a bad assistant message.
A missing (or null) `weight` key defaults to weight 1, to guarantee backward compatibility.

[1]: https://github.com/mistralai/mistral-finetune
This commit is contained in:
DavidFarago
2024-06-20 16:05:16 +02:00
committed by GitHub
parent 4de4b4089f
commit 559562d790
4 changed files with 221 additions and 7 deletions

View File

@@ -143,6 +143,9 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
"weight": 1
if "weight" not in t or t["weight"] is None
else t["weight"],
}
for t in conversations
]

View File

@@ -377,7 +377,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
LOG.warning(f"expected tuple, got {part}")
continue
role, content = part
if len(part) <= 2:
role, content = part
weight = 1
else:
role, content, weight = part
# Uses "in" because role contains extra characters
input_turn = any(r.lower() in role.lower() for r in input_roles)
@@ -403,7 +407,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
@@ -439,13 +443,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
if weight == 0:
# everything from this is masked out from the labels
# (role is masked out too because it makes no sense if contents is masked out)
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels

View File

@@ -319,6 +319,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
conv = self._conversation.copy()
original_source = source.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
@@ -360,8 +361,27 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
return conv.get_turns()
turns = list(conv.get_turns())
original_source_length = len(original_source)
assert len(turns) in [
original_source_length - 1,
original_source_length,
original_source_length + 1,
]
if len(turns) == original_source_length + 1:
original_source = [{"weight": None}] + original_source
elif len(turns) == original_source_length - 1:
original_source = original_source[1:]
return [
(*turn, weight)
for turn, weight in zip(
turns,
[
1 if "weight" not in e or e["weight"] is None else e["weight"]
for e in original_source
],
)
]
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)