fix: chat_template masking due to truncation, consolidate turn build and keys within field (#2123) [skip ci]

* fix: chat_template masking due to truncation, consolidate turn build and keys within field

* fix: revert roles change

* fix: handling of training and training_detail

* fix: do not skip setting eos mask even if failed finding turn boundary

* fix: truncate reward modelling outputs
This commit is contained in:
NanoCode012
2024-12-10 01:49:38 +07:00
committed by GitHub
parent 3862267040
commit 5d6b088997
2 changed files with 112 additions and 72 deletions

View File

@@ -28,6 +28,8 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:return: :return:
""" """
max_length = self.prompter.max_length
self.messages = "chosen_messages" self.messages = "chosen_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
@@ -39,6 +41,16 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt) chosen_tokenized = super().tokenize_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]
self.messages = "rejected_messages" self.messages = "rejected_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
@@ -52,6 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
) )
rejected_tokenized = super().tokenize_prompt(prompt) rejected_tokenized = super().tokenize_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]
return { return {
"input_ids_chosen": chosen_tokenized["input_ids"], "input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"], "attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"), "roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False), "drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit. # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1 "max_length": (
if not cfg.reward_model cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
else cfg.sequence_len, ),
} }
strategy_params = { strategy_params = {

View File

@@ -42,6 +42,7 @@ class ChatTemplatePrompter(Prompter):
"gpt": "assistant", "gpt": "assistant",
"system": "system", "system": "system",
} }
self.message_field_role = message_field_role self.message_field_role = message_field_role
self.message_field_content = message_field_content self.message_field_content = message_field_content
self.message_field_training = message_field_training self.message_field_training = message_field_training
@@ -53,21 +54,9 @@ class ChatTemplatePrompter(Prompter):
self.drop_system_message = drop_system_message self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None): def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor: if self.processor:
text = self.processor.apply_chat_template( text = self.processor.apply_chat_template(
turns, conversation,
chat_template=self.chat_template, chat_template=self.chat_template,
tokenize=False, tokenize=False,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
@@ -76,8 +65,6 @@ class ChatTemplatePrompter(Prompter):
text=text, text=text,
images=images, images=images,
return_tensors="pt", return_tensors="pt",
truncation=True,
max_length=self.max_length,
) )
# workaround since processor works in batches instead of single examples # workaround since processor works in batches instead of single examples
for k, val in batch.items(): for k, val in batch.items():
@@ -88,9 +75,7 @@ class ChatTemplatePrompter(Prompter):
return batch return batch
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
turns, conversation,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template, chat_template=self.chat_template,
) )
@@ -215,7 +200,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None, train_on_eos=None,
): ):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len) super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.roles_to_train = []
if roles_to_train:
# map roles if exist in prompter.roles else use the role as is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]
self.train_on_eos = train_on_eos self.train_on_eos = train_on_eos
self.images = "images" self.images = "images"
@@ -262,30 +254,28 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt return tokenized_prompt
turns = prompt[self.messages] turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids) labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1 last_eos_idx = -1
for index, turn in enumerate(turns): for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role) role = turn.get("role")
content = turn.get(self.prompter.message_field_content) content = turn.get("content")
train_turn = turn.get(self.prompter.message_field_training) train_turn = turn.get("training")
train_detail = turn.get(self.prompter.message_field_training_detail) train_detail = turn.get("training_detail")
LOG.debug( LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}" f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
) )
should_train = ( should_train = None
train_turn if train_turn is not None:
if train_turn is not None should_train = train_turn
else ( elif train_detail is not None:
bool(train_detail is not None) should_train = bool(train_detail)
if train_detail is not None else:
else self.train_on_inputs or role in self.roles_to_train should_train = self.train_on_inputs or role in self.roles_to_train
)
)
LOG.debug(f"Should train: {should_train}") LOG.debug(f"Should train: {should_train}")
@@ -293,6 +283,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
conversation_ids=input_ids, turn=index, turn_content=turn conversation_ids=input_ids, turn=index, turn_content=turn
) )
if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1: if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels[turn_start_idx:turn_end_idx] = input_ids[ labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx turn_start_idx:turn_end_idx
] ]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}") LOG.debug(
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)
LOG.debug(f"Labels after processing turn {index}: {labels}") LOG.debug(f"Labels after processing turn {index}: {labels}")
@@ -351,52 +346,73 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i return i
return -1 return -1
def find_turn(self, conversation_ids, turn, turn_content): def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
""" """
Locate the starting and ending indices of the specified turn in a conversation. Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
""" """
content = turn_content.get(self.prompter.message_field_content, "") content = turn_content.get("content")
content_ids = self.tokenizer.encode(content, add_special_tokens=False) content_ids = self.tokenizer.encode(content, add_special_tokens=False)
eos_token_id = self.tokenizer.eos_token_id LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
eos_count = 0
start_search_idx = 0
# Locate the starting index after the specified number of EOS tokens if not content_ids:
for i, token_id in enumerate(conversation_ids): LOG.warning(f"Empty content for turn {turn}")
if token_id == eos_token_id: return -1, -1
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# Find the start index of the content within the conversation # For first turn, start from beginning
start_idx = -1 if turn == 0:
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1): start_search_idx = 0
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else: else:
end_idx = -1 # For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
return start_idx, end_idx for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break
# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1
if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1
# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)
return -1, -1
def get_conversation_thread(self, prompt): def get_conversation_thread(self, prompt):
return prompt[self.messages] turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]
if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
return turns
def get_images(self, prompt): def get_images(self, prompt):
return prompt.get(self.images, None) return prompt.get(self.images, None)