feat: do not find turn indices if turn is not trainable (#2696)
* feat: do not find turn indices if turn is not trainable
* fix: handle edge case where train_on_eos/train_on_eot is "all"
* fix: improve warning message
This commit is contained in:
@@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
LOG.debug(f"Should train: {should_train}")
|
||||
|
||||
# turn not trainable, skip having to find the turn indices
|
||||
# unless last turn and train_on_eos/train_on_eot is all
|
||||
if not should_train and (
|
||||
self.train_on_eos != "all" and self.train_on_eot != "all"
|
||||
):
|
||||
if index == len(turns) - 1:
|
||||
LOG.warning(
|
||||
"Last turn is not trainable, skipping having to find the turn indices. "
|
||||
"This may cause incorrect last EOT/EOS token to be unmasked."
|
||||
"This is likely a dataset design issue. Please ensure last turn is trainable."
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
Reference in New Issue
Block a user