fix: update Bradley-Terry prompt strategy to use the new chat_template

This commit is contained in:
NanoCode012
2024-10-16 20:42:14 +07:00
parent 207e7627f9
commit 28e7e444ee
2 changed files with 29 additions and 15 deletions

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies") LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg): def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy. Bradley-Terry model with chat template prompt strategy.
""" """
import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import ( from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter, ChatTemplatePrompter,
ChatTemplateStrategy, ChatTemplateStrategy,
) )
from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy): class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
if prompt["system"]: if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]}) prompt[self.messages].append(
prompt[self.messages].append({"from": "user", "value": prompt["input"]}) {"role": "system", "content": prompt["system"]}
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) )
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt) chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages" self.messages = "rejected_messages"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
prompt[self.messages] = [] prompt[self.messages] = []
if prompt["system"]: if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]}) prompt[self.messages].append(
prompt[self.messages].append({"from": "user", "value": prompt["input"]}) {"role": "system", "content": prompt["system"]}
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) )
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt) rejected_tokenized = super().tokenize_prompt(prompt)
return { return {
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {} ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = { prompter_params = {
"tokenizer": tokenizer, "tokenizer": tokenizer,
"chat_template": get_chat_template(ds_cfg.get("chat_template", "chatml")), "chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "from"), "message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "value"), "message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", "training"), "message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get( "message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail" "message_field_training_detail", None
), ),
"roles": ds_cfg.get("roles"), "roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False), "drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = { strategy_params = {
"train_on_inputs": cfg.train_on_inputs, "train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len, "sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), "roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"), "train_on_eos": ds_cfg.get("train_on_eos", None),
} }
strategy = BTChatTemplateStrategy( strategy = BTChatTemplateStrategy(