feat: do not find turn indices if turn is not trainable (#2696)
* feat: do not find turn indices if turn is not trainable * fix: handle edge case where train on eos/eot is all * fix: improve warning message
This commit is contained in:
@@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
LOG.debug(f"Should train: {should_train}")
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
|
||||||
|
# turn not trainable, skip having to find the turn indices
|
||||||
|
# unless last turn and train_on_eos/train_on_eot is all
|
||||||
|
if not should_train and (
|
||||||
|
self.train_on_eos != "all" and self.train_on_eot != "all"
|
||||||
|
):
|
||||||
|
if index == len(turns) - 1:
|
||||||
|
LOG.warning(
|
||||||
|
"Last turn is not trainable, skipping having to find the turn indices. "
|
||||||
|
"This may cause incorrect last EOT/EOS token to be unmasked."
|
||||||
|
"This is likely a dataset design issue. Please ensure last turn is trainable."
|
||||||
|
)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
||||||
|
|
||||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||||
|
|||||||
Reference in New Issue
Block a user