improve check for base case
This commit is contained in:
@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
# Old simple legacy behavior that works reliably.
|
# Old simple legacy behavior that works reliably.
|
||||||
if (
|
if (
|
||||||
not self.roles_to_train
|
(not self.roles_to_train or self.roles_to_train == ["assistant"])
|
||||||
and not self.train_on_eos
|
and not self.train_on_eos
|
||||||
and not self.prompter.message_field_training
|
and not self.prompter.message_field_training
|
||||||
and not self.prompter.message_field_training_detail
|
and not self.prompter.message_field_training_detail
|
||||||
@@ -487,7 +487,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
|||||||
strategy_params = {
|
strategy_params = {
|
||||||
"train_on_inputs": cfg.train_on_inputs,
|
"train_on_inputs": cfg.train_on_inputs,
|
||||||
"sequence_len": cfg.sequence_len,
|
"sequence_len": cfg.sequence_len,
|
||||||
"roles_to_train": ds_cfg.get("roles_to_train", None),
|
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user