fix: update bradleyterry to use new chat_template
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
|||||||
|
|
||||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.prompt_strategies")
|
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||||
|
|||||||
@@ -2,13 +2,18 @@
|
|||||||
Bradley-Terry model with chat template prompt strategy.
|
Bradley-Terry model with chat template prompt strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from axolotl.prompt_strategies.chat_template import (
|
from axolotl.prompt_strategies.chat_template import (
|
||||||
ChatTemplatePrompter,
|
ChatTemplatePrompter,
|
||||||
ChatTemplateStrategy,
|
ChatTemplateStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.utils.chat_templates import get_chat_template
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
|
||||||
|
# Configure the logger
|
||||||
|
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
|
||||||
|
LOG.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||||
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt[self.messages] = []
|
||||||
if prompt["system"]:
|
if prompt["system"]:
|
||||||
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
|
prompt[self.messages].append(
|
||||||
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
|
{"role": "system", "content": prompt["system"]}
|
||||||
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
|
)
|
||||||
|
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||||
|
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||||
|
|
||||||
self.messages = "rejected_messages"
|
self.messages = "rejected_messages"
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt[self.messages] = []
|
||||||
if prompt["system"]:
|
if prompt["system"]:
|
||||||
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
|
prompt[self.messages].append(
|
||||||
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
|
{"role": "system", "content": prompt["system"]}
|
||||||
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
|
)
|
||||||
|
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||||
|
prompt[self.messages].append(
|
||||||
|
{"role": "assistant", "content": prompt["rejected"]}
|
||||||
|
)
|
||||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
ds_cfg = ds_cfg or {}
|
ds_cfg = ds_cfg or {}
|
||||||
|
chat_template_string = get_chat_template_from_config(
|
||||||
|
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": get_chat_template(ds_cfg.get("chat_template", "chatml")),
|
"chat_template": chat_template_string,
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", "training"),
|
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training_detail": ds_cfg.get(
|
||||||
"message_field_training_detail", "train_detail"
|
"message_field_training_detail", None
|
||||||
),
|
),
|
||||||
"roles": ds_cfg.get("roles"),
|
"roles": ds_cfg.get("roles"),
|
||||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||||
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strategy_params = {
|
strategy_params = {
|
||||||
"train_on_inputs": cfg.train_on_inputs,
|
"train_on_inputs": cfg.train_on_inputs,
|
||||||
"sequence_len": cfg.sequence_len,
|
"sequence_len": cfg.sequence_len,
|
||||||
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
|
"roles_to_train": ds_cfg.get("roles_to_train", []),
|
||||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
"train_on_eos": ds_cfg.get("train_on_eos", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy = BTChatTemplateStrategy(
|
strategy = BTChatTemplateStrategy(
|
||||||
|
|||||||
Reference in New Issue
Block a user