fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]
* fix: chat_template masking due to truncation, consolidate turn build and keys within field
* fix: revert roles change
* fix: handling of training and training_detail
* fix: do not skip setting eos mask even if failed finding turn boundary
* fix: truncate reward modelling outputs
This commit is contained in:
@@ -28,6 +28,8 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
max_length = self.prompter.max_length
|
||||||
|
|
||||||
self.messages = "chosen_messages"
|
self.messages = "chosen_messages"
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt[self.messages] = []
|
||||||
@@ -39,6 +41,16 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||||
|
|
||||||
|
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||||
|
LOG.warning(
|
||||||
|
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
|
||||||
|
)
|
||||||
|
|
||||||
|
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
|
||||||
|
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
|
||||||
|
:max_length
|
||||||
|
]
|
||||||
|
|
||||||
self.messages = "rejected_messages"
|
self.messages = "rejected_messages"
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt[self.messages] = []
|
||||||
@@ -52,6 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
)
|
)
|
||||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||||
|
|
||||||
|
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||||
|
LOG.warning(
|
||||||
|
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
|
||||||
|
)
|
||||||
|
|
||||||
|
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
|
||||||
|
:max_length
|
||||||
|
]
|
||||||
|
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
|
||||||
|
:max_length
|
||||||
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
"input_ids_chosen": chosen_tokenized["input_ids"],
|
||||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||||
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
"roles": ds_cfg.get("roles"),
|
"roles": ds_cfg.get("roles"),
|
||||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||||
# we need to add one to detect sequences exceeding the `sequence_len` limit.
|
# we need to add one to detect sequences exceeding the `sequence_len` limit.
|
||||||
"max_length": cfg.sequence_len + 1
|
"max_length": (
|
||||||
if not cfg.reward_model
|
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
|
||||||
else cfg.sequence_len,
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy_params = {
|
strategy_params = {
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
"gpt": "assistant",
|
"gpt": "assistant",
|
||||||
"system": "system",
|
"system": "system",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.message_field_role = message_field_role
|
self.message_field_role = message_field_role
|
||||||
self.message_field_content = message_field_content
|
self.message_field_content = message_field_content
|
||||||
self.message_field_training = message_field_training
|
self.message_field_training = message_field_training
|
||||||
@@ -53,21 +54,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||||
turns = [
|
|
||||||
{
|
|
||||||
"role": self.roles[t[self.message_field_role]],
|
|
||||||
"content": t[self.message_field_content],
|
|
||||||
"training": t.get(self.message_field_training, None),
|
|
||||||
}
|
|
||||||
for t in conversation
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.drop_system_message and turns[0]["role"] == "system":
|
|
||||||
turns = turns[1:]
|
|
||||||
|
|
||||||
if self.processor:
|
if self.processor:
|
||||||
text = self.processor.apply_chat_template(
|
text = self.processor.apply_chat_template(
|
||||||
turns,
|
conversation,
|
||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
@@ -76,8 +65,6 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
text=text,
|
text=text,
|
||||||
images=images,
|
images=images,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
truncation=True,
|
|
||||||
max_length=self.max_length,
|
|
||||||
)
|
)
|
||||||
# workaround since processor works in batches instead of single examples
|
# workaround since processor works in batches instead of single examples
|
||||||
for k, val in batch.items():
|
for k, val in batch.items():
|
||||||
@@ -88,9 +75,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
return self.tokenizer.apply_chat_template(
|
return self.tokenizer.apply_chat_template(
|
||||||
turns,
|
conversation,
|
||||||
truncation=True,
|
|
||||||
max_length=self.max_length,
|
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
)
|
)
|
||||||
@@ -215,7 +200,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
train_on_eos=None,
|
train_on_eos=None,
|
||||||
):
|
):
|
||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
self.roles_to_train = roles_to_train if roles_to_train is not None else []
|
|
||||||
|
self.roles_to_train = []
|
||||||
|
if roles_to_train:
|
||||||
|
# map roles if exist in prompter.roles else use the role as is
|
||||||
|
self.roles_to_train = [
|
||||||
|
prompter.roles.get(role, role) for role in roles_to_train
|
||||||
|
]
|
||||||
|
|
||||||
self.train_on_eos = train_on_eos
|
self.train_on_eos = train_on_eos
|
||||||
self.images = "images"
|
self.images = "images"
|
||||||
|
|
||||||
@@ -262,30 +254,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
turns = prompt[self.messages]
|
turns = self.get_conversation_thread(prompt)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns)
|
||||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
|
|
||||||
last_eos_idx = -1
|
last_eos_idx = -1
|
||||||
for index, turn in enumerate(turns):
|
for index, turn in enumerate(turns):
|
||||||
role = turn.get(self.prompter.message_field_role)
|
role = turn.get("role")
|
||||||
content = turn.get(self.prompter.message_field_content)
|
content = turn.get("content")
|
||||||
train_turn = turn.get(self.prompter.message_field_training)
|
train_turn = turn.get("training")
|
||||||
train_detail = turn.get(self.prompter.message_field_training_detail)
|
train_detail = turn.get("training_detail")
|
||||||
|
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
|
||||||
)
|
)
|
||||||
|
|
||||||
should_train = (
|
should_train = None
|
||||||
train_turn
|
if train_turn is not None:
|
||||||
if train_turn is not None
|
should_train = train_turn
|
||||||
else (
|
elif train_detail is not None:
|
||||||
bool(train_detail is not None)
|
should_train = bool(train_detail)
|
||||||
if train_detail is not None
|
else:
|
||||||
else self.train_on_inputs or role in self.roles_to_train
|
should_train = self.train_on_inputs or role in self.roles_to_train
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug(f"Should train: {should_train}")
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
|
||||||
@@ -293,6 +283,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
conversation_ids=input_ids, turn=index, turn_content=turn
|
conversation_ids=input_ids, turn=index, turn_content=turn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if turn_start_idx == -1 or turn_end_idx == -1:
|
||||||
|
LOG.warning(f"Failed to find boundaries for turn {index}")
|
||||||
|
|
||||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||||
|
|
||||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||||
@@ -313,7 +306,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
labels[turn_start_idx:turn_end_idx] = input_ids[
|
labels[turn_start_idx:turn_end_idx] = input_ids[
|
||||||
turn_start_idx:turn_end_idx
|
turn_start_idx:turn_end_idx
|
||||||
]
|
]
|
||||||
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
|
LOG.debug(
|
||||||
|
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
LOG.debug(f"Labels after processing turn {index}: {labels}")
|
||||||
|
|
||||||
@@ -351,52 +346,73 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return i
|
return i
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def find_turn(self, conversation_ids, turn, turn_content):
|
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
|
||||||
"""
|
"""
|
||||||
Locate the starting and ending indices of the specified turn in a conversation.
|
Locate the starting and ending indices of the specified turn in a conversation.
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_ids (list[int]): Token IDs representing the conversation.
|
|
||||||
turn (int): The turn number to locate (based on EOS tokens).
|
|
||||||
turn_content (str): String containing the content of the turn.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
|
|
||||||
Returns (-1, -1) if the turn content is not found.
|
|
||||||
"""
|
"""
|
||||||
content = turn_content.get(self.prompter.message_field_content, "")
|
content = turn_content.get("content")
|
||||||
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
|
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
|
||||||
|
|
||||||
eos_token_id = self.tokenizer.eos_token_id
|
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
|
||||||
eos_count = 0
|
|
||||||
start_search_idx = 0
|
|
||||||
|
|
||||||
# Locate the starting index after the specified number of EOS tokens
|
if not content_ids:
|
||||||
for i, token_id in enumerate(conversation_ids):
|
LOG.warning(f"Empty content for turn {turn}")
|
||||||
if token_id == eos_token_id:
|
return -1, -1
|
||||||
eos_count += 1
|
|
||||||
if eos_count == turn:
|
|
||||||
start_search_idx = (
|
|
||||||
i + 1
|
|
||||||
) # Start searching after the specified turn's EOS token
|
|
||||||
break
|
|
||||||
|
|
||||||
# Find the start index of the content within the conversation
|
# For first turn, start from beginning
|
||||||
start_idx = -1
|
if turn == 0:
|
||||||
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
|
start_search_idx = 0
|
||||||
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
|
||||||
start_idx = i
|
|
||||||
break
|
|
||||||
|
|
||||||
if start_idx != -1:
|
|
||||||
end_idx = start_idx + len(content_ids)
|
|
||||||
else:
|
else:
|
||||||
end_idx = -1
|
# For subsequent turns, find the previous EOS token
|
||||||
|
eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
eos_count = 0
|
||||||
|
start_search_idx = 0
|
||||||
|
|
||||||
return start_idx, end_idx
|
for i, token_id in enumerate(conversation_ids):
|
||||||
|
if token_id == eos_token_id:
|
||||||
|
eos_count += 1
|
||||||
|
if eos_count == turn: # Find the nth EOS token where n = turn
|
||||||
|
start_search_idx = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# we can optimize this to only search for a few tokens from start_search_idx
|
||||||
|
# but it would risk missing the content if it's not found within the first few tokens or
|
||||||
|
# if start_search_idx cannot be found above.
|
||||||
|
last_index = len(conversation_ids) - len(content_ids) + 1
|
||||||
|
|
||||||
|
if last_index < start_search_idx:
|
||||||
|
LOG.warning(
|
||||||
|
f"last_index to search is less than start_search_idx for turn {turn}"
|
||||||
|
)
|
||||||
|
return -1, -1
|
||||||
|
|
||||||
|
# Search for content starting from start_search_idx
|
||||||
|
first_elem = content_ids[0]
|
||||||
|
for i in range(start_search_idx, last_index):
|
||||||
|
# Quick check of first element before doing full comparison
|
||||||
|
if conversation_ids[i] == first_elem:
|
||||||
|
# Check if the rest of the content matches
|
||||||
|
if conversation_ids[i : i + len(content_ids)] == content_ids:
|
||||||
|
LOG.debug(f"Found turn {turn} content at position {i}")
|
||||||
|
return i, i + len(content_ids)
|
||||||
|
|
||||||
|
return -1, -1
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
return prompt[self.messages]
|
turns = [
|
||||||
|
{
|
||||||
|
"role": self.prompter.roles[t[self.prompter.message_field_role]],
|
||||||
|
"content": t[self.prompter.message_field_content],
|
||||||
|
"training": t.get(self.prompter.message_field_training),
|
||||||
|
"training_detail": t.get(self.prompter.message_field_training_detail),
|
||||||
|
}
|
||||||
|
for t in prompt[self.messages]
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||||
|
turns = turns[1:]
|
||||||
|
|
||||||
|
return turns
|
||||||
|
|
||||||
def get_images(self, prompt):
|
def get_images(self, prompt):
|
||||||
return prompt.get(self.images, None)
|
return prompt.get(self.images, None)
|
||||||
|
|||||||
Reference in New Issue
Block a user