From aa0492c366d32645481c80b5c60a86f53f7670d7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 19:19:59 +0700 Subject: [PATCH] feat: do not find turn indices if turn is not trainable (#2696) * feat: do not find turn indices if turn is not trainable * fix: handle edge case where train on eos/eot is all * fix: improve warning message --- src/axolotl/prompt_strategies/chat_template.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 638cee559..047a66e94 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): LOG.debug(f"Should train: {should_train}") + # turn not trainable, skip having to find the turn indices + # unless last turn and train_on_eos/train_on_eot is all + if not should_train and ( + self.train_on_eos != "all" and self.train_on_eot != "all" + ): + if index == len(turns) - 1: + LOG.warning( + "Last turn is not trainable, skipping having to find the turn indices. " + "This may cause incorrect last EOT/EOS token to be unmasked." + "This is likely a dataset design issue. Please ensure last turn is trainable." + ) + + continue + turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index) LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")