feat: do not find turn indices if turn is not trainable (#2696)

* feat: do not find turn indices if turn is not trainable

* fix: handle edge case where train on eos/eot is all

* fix: improve warning message
This commit is contained in:
NanoCode012
2025-05-22 19:19:59 +07:00
committed by GitHub
parent 798b5f5cfd
commit aa0492c366

View File

@@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(f"Should train: {should_train}")
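# turn is not trainable: skip finding the turn indices,
# unless train_on_eos or train_on_eot is "all" (then indices are still needed for every turn);
# if the skipped turn is the last one, a warning is logged first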
if not should_train and (
self.train_on_eos != "all" and self.train_on_eot != "all"
):
if index == len(turns) - 1:
LOG.warning(
"Last turn is not trainable, skipping having to find the turn indices. "
"This may cause incorrect last EOT/EOS token to be unmasked."
"This is likely a dataset design issue. Please ensure last turn is trainable."
)
continue
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")