improve check for base case

This commit is contained in:
Wing Lian
2025-01-24 12:02:34 -05:00
parent 94c226edb3
commit 6c49083d8b

View File

@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt):
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
(not self.roles_to_train or self.roles_to_train == ["assistant"])
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
@@ -487,7 +487,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", None),
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}