diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 849d84e45..4457c50be 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -6,7 +6,7 @@ import logging from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig -LOG = logging.getLogger("axolotl.prompt_strategies") +LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry") def load(strategy, tokenizer, cfg, ds_cfg): diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index d3cd5c2f5..fa85cdcb2 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -2,13 +2,18 @@ Bradley-Terry model with chat template prompt strategy. """ +import logging from typing import Any, Dict, Optional from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, ) -from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.chat_templates import get_chat_template_from_config + +# Configure the logger +LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template") +LOG.setLevel(logging.INFO) class BTChatTemplateStrategy(ChatTemplateStrategy): @@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) chosen_tokenized = super().tokenize_prompt(prompt) self.messages = "rejected_messages" # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append( + {"role": "assistant", "content": prompt["rejected"]} + ) rejected_tokenized = super().tokenize_prompt(prompt) return { @@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) prompter_params = { "tokenizer": tokenizer, - "chat_template": get_chat_template(ds_cfg.get("chat_template", "chatml")), - "message_field_role": ds_cfg.get("message_field_role", "from"), - "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "chat_template": chat_template_string, + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", None ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = BTChatTemplateStrategy(