Phi-3 conversation format, example training script and perplexity metric (#1582)
* phi-3 support and perplexity metric
* phi-3 chat template
* metrics updates
* chore: lint
* fix assertion on Tensor
* fix tests since tokenization happens in the metric
* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
@@ -20,6 +20,7 @@ class PromptStyle(Enum):
     INSTRUCT = "instruct"
     CHAT = "chat"
     CHATML = "chatml"
+    PHI = "phi"


 class Prompter:
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
     system_format: str = "{system}"
     turn_format: str
     turn_no_input_format: str
-    prompt_style: Optional[PromptStyle] = None
+    prompt_style: Optional[str] = None

-    def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
+    def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
         self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
         self.match_prompt_style()

@@ -52,16 +53,20 @@ class AlpacaPrompter(Prompter):
                 "### Instruction:\n{instruction}\n\n### Response:\n"
             )
             self.system_format = "{system}\n\n"
-        if self.prompt_style == PromptStyle.CHAT.value:
+        elif self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
             self.system_format = "SYSTEM: {system}\n"
-        if self.prompt_style == PromptStyle.CHATML.value:
+        elif self.prompt_style == PromptStyle.CHATML.value:
             self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
             self.turn_no_input_format = (
                 "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
             )
             self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
+        elif self.prompt_style == PromptStyle.PHI.value:
+            self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
+            self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
+            self.system_format = "<|system|>{system}\n"

     def _build_result(self, instruction, input_text, output):
         # returns the full prompt from instruction and optional input
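For orientation, here is what the new PHI style renders for a single no-input turn; a minimal sketch using the format strings from this hunk (the direct `.format()` calls stand in for what `AlpacaPrompter` does internally):

```python
# Minimal sketch: render one PHI-style turn with the format strings above.
system_format = "<|system|>{system}\n"
turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"

prompt = system_format.format(system="You are a helpful assistant.")
prompt += turn_no_input_format.format(instruction="Summarize the article.")
print(prompt)
# <|system|>You are a helpful assistant.
# <|user|>
# Summarize the article.<|end|><|assistant|>
```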
@@ -381,12 +386,14 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
         conversation: Optional[Union[str, Conversation]] = None,
         role_key_human: Optional[str] = None,
         role_key_model: Optional[str] = None,
+        role_key_tool: Optional[str] = None,
         roles: Optional[dict] = None,
     ):
         super().__init__(
             conversation=conversation,
             role_key_human=role_key_human,
             role_key_model=role_key_model,
+            role_key_tool=role_key_tool,
             roles=roles,
         )

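A hypothetical instantiation showing the new `role_key_tool` pass-through; the key names below are illustrative, not taken from the diff:

```python
# Hypothetical usage sketch: role_key_tool is simply forwarded to the
# parent ShareGPTPrompter. Key names here are illustrative.
prompter = ShareGPTPrompterV2(
    role_key_human="human",
    role_key_model="gpt",
    role_key_tool="tool",
)
```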
@@ -5,6 +5,7 @@ from __future__ import annotations
 import logging
 import math
 import os
+import traceback
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Dict, List
@@ -30,6 +31,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

 from axolotl.utils import is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.callbacks.perplexity import Perplexity
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
@@ -374,10 +376,14 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
         def __maybe_load_metrics(self):
             metrics = {}
             for metric in self.cfg.eval_causal_lm_metrics:
-                try:
-                    metrics[metric] = evaluate.load(metric)
-                except Exception as exc:  # pylint: disable=broad-exception-caught
-                    LOG.warning(f"{metric}: {exc.args}")
+                if metric == "perplexity":
+                    max_seq_len = self.cfg.eval_max_new_tokens
+                    metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
+                else:
+                    try:
+                        metrics[metric] = evaluate.load(metric)
+                    except Exception as exc:  # pylint: disable=broad-exception-caught
+                        LOG.warning(f"{metric}: {exc.args}")
             return metrics

         def on_evaluate(
@@ -421,13 +427,20 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                 # safely compute a metric and return the score if the format is correct
                 metric_score = None
                 try:
-                    metric_score = metric.compute(**kwargs)
+                    # Only pass the kwargs that are in the metric's feature list
+                    metric_kwargs = {
+                        k: kwargs[k]
+                        for k in metric._feature_names()  # pylint: disable=protected-access
+                        if k in kwargs
+                    }
+                    metric_score = metric.compute(**metric_kwargs)
                     return (
                         metric_score["score"]
                         if "score" in metric_score
                         else metric_score["mean_score"]
                     )
                 except Exception:  # pylint: disable=broad-exception-caught
+                    traceback.print_exc()
                     LOG.debug(
                         f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
                     )
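With this change, each metric only receives the inputs it declares via `_feature_names()`, which is also why the new `Perplexity` class advertises just `references`. A standalone sketch of the pattern, using a stand-in metric object:

```python
# Standalone sketch of the kwargs filtering above: keys not in the metric's
# feature list (here, "sources") are dropped before compute() is called.
class DummyMetric:  # stand-in for an evaluate module or the Perplexity class
    def _feature_names(self):
        return ["references", "predictions"]

    def compute(self, **metric_kwargs):
        return {"score": 1.0}

kwargs = {"references": [["a ref"]], "predictions": ["a pred"], "sources": ["a src"]}
metric = DummyMetric()
metric_kwargs = {k: kwargs[k] for k in metric._feature_names() if k in kwargs}
print(metric.compute(**metric_kwargs))  # {'score': 1.0}; "sources" was filtered out
```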
@@ -443,11 +456,12 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                         predictions=predictions,
                         sources=sources,
                     )
-                    score = score or compute(
-                        metric,
-                        references=[[r] for r in references],
-                        predictions=predictions,
-                    )
+                    if score is None:
+                        score = compute(
+                            metric,
+                            references=[[r] for r in references],
+                            predictions=predictions,
+                        )
                     scores[metric_name] = score
                 return scores

src/axolotl/utils/callbacks/perplexity.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+"""callback to calculate perplexity as an evaluation metric."""
+from typing import Dict, List, Optional
+
+import torch
+from torch import Tensor
+from tqdm import tqdm
+from transformers.modeling_outputs import CausalLMOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+
+class Perplexity:
+    """
+    Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
+    This is a custom variant that doesn't re-tokenize the input or re-load the model.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        max_seq_len: int,
+        stride: int = 512,
+    ) -> None:
+        self.max_seq_len = max_seq_len
+        self.stride = stride
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = model.device
+        self.name = "perplexity"
+
+    def _feature_names(self) -> List[str]:
+        return ["references"]
+
+    def compute(
+        self,
+        references: Optional[List[str]] = None,
+    ) -> Dict[str, float]:
+        """
+        Compute perplexity in a fixed length sliding window across the sequence.
+        """
+        assert references is not None, "Missing parameter: references"
+
+        references_tokenized = self.tokenizer(
+            references, return_tensors="pt", padding=True, truncation=True
+        )
+        input_ids: Tensor = references_tokenized["input_ids"]  # type: ignore
+        input_ids = input_ids.to(self.device)
+
+        sequence_length = input_ids.size(1)
+
+        losses = []
+        prev_end_loc = 0
+        for begin_loc in tqdm(range(0, sequence_length, self.stride)):
+            end_loc = min(begin_loc + self.max_seq_len, sequence_length)
+            trg_len = end_loc - prev_end_loc
+            input_ids_slice = input_ids[:, begin_loc:end_loc]
+            labels_slice = input_ids_slice.clone()
+            labels_slice[:, :-trg_len] = -100
+
+            with torch.no_grad():
+                outputs: CausalLMOutput = self.model(
+                    input_ids=input_ids_slice, labels=labels_slice
+                )
+
+            losses.append(outputs.loss)
+
+            prev_end_loc = end_loc
+            if end_loc == sequence_length:
+                break
+
+        perplexity = torch.exp(torch.stack(losses).mean()).item()
+
+        return {
+            "score": perplexity,
+        }
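A minimal usage sketch of the new class outside the trainer; the model, tokenizer, and checkpoint name are assumptions for illustration. Note the loop breaks as soon as the window covers the whole sequence, which keeps the score correct for passages shorter than `max_seq_len` (the "fix perplexity value of shorter passage" item in the commit message):

```python
# Assumed standalone usage; any causal LM works, "gpt2" is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.utils.callbacks.perplexity import Perplexity

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

ppl = Perplexity(model, tokenizer, max_seq_len=1024)
result = ppl.compute(references=["The quick brown fox jumps over the lazy dog."])
print(result["score"])  # exp of the mean sliding-window loss
```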
@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
         "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
+        "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
     }

     if user_choice in templates:
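A hedged sketch of rendering the new template through `tokenizer.apply_chat_template`, assuming the helper lives at `axolotl.utils.chat_templates`; the checkpoint name (and whether it needs `trust_remote_code`) is an assumption:

```python
# Sketch: override a tokenizer's chat template with the phi_3 string above.
# The checkpoint name is illustrative.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.chat_template = chat_templates("phi_3")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is perplexity?"},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
# <s><|system|>
# You are a helpful assistant.<|end|>
# <|user|>
# What is perplexity?<|end|>
# <|assistant|>
```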
@@ -10,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config.models.input.v0_4_1 import (
+    SUPPORTED_METRICS,
     AxolotlConfigWCapabilities,
     AxolotlInputConfig,
 )
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
         )

     if cfg.eval_causal_lm_metrics:
-        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
         if not isinstance(cfg.eval_causal_lm_metrics, list):
             raise ValueError("eval_causal_lm_metrics must be a list")
-        # only ["sacrebleu", "comet", "ter", "chrf"] supported
-        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
+        if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
             raise ValueError(
-                f"eval_causal_lm_metrics must be one of {supported_metrics}"
+                f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
             )

     # TODO
@@ -17,6 +17,8 @@ from axolotl.utils.config.models.internals import GPUCapabilities

 LOG = logging.getLogger("axolotl.utils.config.models.input")

+SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
+

 class DeprecatedParameters(BaseModel):
     """configurations that are deprecated"""
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
     gemma = "gemma"  # pylint: disable=invalid-name
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
+    phi_3 = "phi_3"  # pylint: disable=invalid-name


 class LoftQConfig(BaseModel):
@@ -1073,13 +1076,12 @@ class AxolotlInputConfig(
             )

         if data.get("eval_causal_lm_metrics"):
-            supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
             if not isinstance(data.get("eval_causal_lm_metrics"), list):
                 raise ValueError("eval_causal_lm_metrics must be a list")
-            # only ["sacrebleu", "comet", "ter", "chrf"] supported
-            if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
+            if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
                 raise ValueError(
-                    f"eval_causal_lm_metrics must be one of {supported_metrics}"
+                    f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
                 )
         return data
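Both validators now share one membership check against the module-level set; a sketch of the behavior it enforces:

```python
# Sketch of the shared membership check. An unsupported metric name in
# eval_causal_lm_metrics now raises before training starts.
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

eval_causal_lm_metrics = ["perplexity", "rouge"]
if set(eval_causal_lm_metrics) - SUPPORTED_METRICS:
    raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
# raises ValueError: "rouge" is not in the supported set
```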
@@ -474,12 +474,16 @@ def load_prepare_datasets(
             index=cfg.dataset_shard_idx,
         )

-    if split == "train" and cfg.val_set_size:
+    val_set_size = (
+        int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
+    )
+
+    if split == "train" and val_set_size:
         # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
         to_hash_train = (
             dataset._fingerprint  # pylint: disable=protected-access
             + "|"
-            + str(cfg.val_set_size)
+            + str(val_set_size)
             + "|"
             + "train"
             + "|"
@@ -488,7 +492,7 @@ def load_prepare_datasets(
         to_hash_test = (
             dataset._fingerprint  # pylint: disable=protected-access
             + "|"
-            + str(cfg.val_set_size)
+            + str(val_set_size)
             + "|"
             + "test"
             + "|"
@@ -498,9 +502,7 @@ def load_prepare_datasets(
         test_fingerprint = md5(to_hash_test)

         dataset = dataset.train_test_split(
-            test_size=int(cfg.val_set_size)
-            if cfg.val_set_size == int(cfg.val_set_size)
-            else cfg.val_set_size,
+            test_size=val_set_size,
             shuffle=False,
             seed=cfg.seed or 42,
             train_new_fingerprint=train_fingerprint,
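The int-vs-float normalization matters because `datasets.Dataset.train_test_split` reads an int `test_size` as an absolute row count and a float as a fraction; a sketch:

```python
# Sketch: the int vs. float test_size semantics the normalization relies on.
from datasets import Dataset

dataset = Dataset.from_dict({"text": [f"row {i}" for i in range(100)]})

for raw in (0.1, 20):
    val_set_size = int(raw) if raw > 1 else float(raw)
    split = dataset.train_test_split(test_size=val_set_size, shuffle=False)
    print(type(val_set_size).__name__, len(split["test"]))
# float 10  -> 10% of 100 rows
# int 20    -> exactly 20 rows
```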
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
         "keep_in_memory": cfg.dataset_keep_in_memory is True,
     }

+    LOG.info(
+        f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
+    )
+
     if (
         isinstance(dataset, Dataset)
         and "input_ids" in dataset.features