ORPO (#1419)
* orpo trainer * rl handling for orpo * support for remove_unused_columns * orpo fixes * fix loader for orpo * chore: lint * fix default for remove_unused_columns * roll ORPO into the main AxolotlTrainer so it can be compatible with some of the other techniques like relora * better handling of system message for orpo * revert system prompt changes for chat templtes * no need for else condition * split dataset parsing into it's own component
This commit is contained in:
15
docs/rlhf.md
15
docs/rlhf.md
@@ -34,6 +34,21 @@ datasets:
|
||||
rl: ipo
|
||||
```
|
||||
|
||||
#### ORPO
|
||||
|
||||
Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
```yaml
|
||||
rl: orpo
|
||||
orpo_alpha: 0.1
|
||||
remove_unused_columns: false
|
||||
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: orpo.chat_template
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
```yaml
|
||||
datasets:
|
||||
|
||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
if parsed_cfg.rl:
|
||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.rl:
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,10 +11,11 @@ import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -223,6 +227,9 @@ class AxolotlTrainer(Trainer):
|
||||
self.eval_data_collator = eval_data_collator
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
@@ -465,8 +472,112 @@ class AxolotlTrainer(Trainer):
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
def orpo_compute_custom_loss(self, logits, labels):
|
||||
logits = logits.contiguous()
|
||||
loss = 0.0
|
||||
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def orpo_compute_logps(
|
||||
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
|
||||
):
|
||||
# Get the shape of chosen_attention_mask[:, :-1]
|
||||
chosen_shape = chosen_attention_mask[:, :-1].shape
|
||||
|
||||
# Calculate the padding size
|
||||
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
|
||||
|
||||
# Pad prompt_attention_mask with zeros to match the desired shape
|
||||
prompt_attention_mask_padded = torch.nn.functional.pad(
|
||||
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
|
||||
)
|
||||
|
||||
# Perform the subtraction operation
|
||||
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
|
||||
|
||||
per_token_logps = torch.gather(
|
||||
logits[:, :-1, :].log_softmax(-1),
|
||||
dim=2,
|
||||
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
||||
).squeeze(2)
|
||||
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
|
||||
dtype=torch.float64
|
||||
) / mask.sum(dim=1).to(dtype=torch.float64)
|
||||
|
||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||
outputs_neg = model(
|
||||
**{
|
||||
"input_ids": inputs["rejected_input_ids"],
|
||||
"attention_mask": inputs["rejected_attention_mask"],
|
||||
"labels": inputs["rejected_labels"],
|
||||
},
|
||||
output_hidden_states=True,
|
||||
)
|
||||
outputs_pos = model(
|
||||
**{
|
||||
"input_ids": inputs["input_ids"],
|
||||
"attention_mask": inputs["attention_mask"],
|
||||
"labels": inputs["labels"],
|
||||
},
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# Calculate NLL loss
|
||||
pos_loss = self.orpo_compute_custom_loss(
|
||||
logits=outputs_pos.logits, labels=inputs["input_ids"]
|
||||
)
|
||||
|
||||
# Calculate Log Probability
|
||||
pos_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
||||
chosen_inputs=inputs["input_ids"],
|
||||
chosen_attention_mask=inputs["attention_mask"],
|
||||
logits=outputs_pos.logits,
|
||||
)
|
||||
neg_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
||||
chosen_inputs=inputs["rejected_input_ids"],
|
||||
chosen_attention_mask=inputs["rejected_attention_mask"],
|
||||
logits=outputs_neg.logits,
|
||||
)
|
||||
|
||||
# Calculate log odds
|
||||
log_odds = (pos_prob - neg_prob) - (
|
||||
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
|
||||
)
|
||||
sig_ratio = torch.nn.functional.sigmoid(log_odds)
|
||||
ratio = torch.log(sig_ratio)
|
||||
|
||||
# Calculate the Final Loss
|
||||
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
|
||||
dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
metrics = {}
|
||||
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
|
||||
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
|
||||
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
|
||||
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
return (loss, outputs_pos) if return_outputs else loss
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
@@ -527,6 +638,28 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
) -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -903,6 +1036,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_arguments_kwargs[
|
||||
"remove_unused_columns"
|
||||
] = self.cfg.remove_unused_columns
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
@@ -1070,6 +1208,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
|
||||
20
src/axolotl/prompt_strategies/base.py
Normal file
20
src/axolotl/prompt_strategies/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
module for base dataset transform strategies
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg, module_base=None, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", module_base)
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
@@ -1,20 +1,8 @@
|
||||
"""
|
||||
module for DPO style dataset transform strategies
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from ..base import load as load_base
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.dpo")
|
||||
|
||||
9
src/axolotl/prompt_strategies/orpo/__init__.py
Normal file
9
src/axolotl/prompt_strategies/orpo/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
module for ORPO style dataset transform strategies
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
||||
187
src/axolotl/prompt_strategies/orpo/chat_template.py
Normal file
187
src/axolotl/prompt_strategies/orpo/chat_template.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""chatml prompt tokenization strategy for ORPO"""
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""message/turn"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
label: Optional[bool] = None
|
||||
|
||||
|
||||
class MessageList(BaseModel):
|
||||
"""conversation"""
|
||||
|
||||
messages: List[Message]
|
||||
|
||||
|
||||
def load(
|
||||
tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
"""
|
||||
|
||||
chat_template = chat_templates("chatml")
|
||||
if ds_cfg and "chat_template" in ds_cfg:
|
||||
chat_template = ds_cfg["chat_template"]
|
||||
try:
|
||||
chat_template = chat_templates(chat_template)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return ORPOTokenizingStrategy(
|
||||
ORPOPrompter(chat_template, tokenizer),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
dataset_parser=ORPODatasetParsingStrategy(),
|
||||
)
|
||||
|
||||
|
||||
class ORPODatasetParsingStrategy:
|
||||
"""Strategy to parse chosen rejected dataset into messagelist"""
|
||||
|
||||
def get_chosen_conversation_thread(self, prompt) -> MessageList:
|
||||
"""Dataset structure mappings"""
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["chosen"][1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_rejected_conversation_thread(self, prompt) -> MessageList:
|
||||
"""Dataset structure mappings"""
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["rejected"][1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
|
||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
rejected_input_ids
|
||||
input_ids
|
||||
rejected_attention_mask
|
||||
attention_mask
|
||||
rejected_labels
|
||||
labels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
dataset_parser=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_parser = dataset_parser
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pass the rejected prompt/row to the Prompter to get the formatted prompt
|
||||
prompt_len = 0
|
||||
rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
|
||||
prompt
|
||||
)
|
||||
input_ids = []
|
||||
labels = []
|
||||
for _, (part, label) in enumerate(
|
||||
self.prompter.build_prompt(rejected_message_list)
|
||||
):
|
||||
if not part:
|
||||
continue
|
||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
||||
prev_idx = len(input_ids)
|
||||
input_ids += _input_ids[prev_idx:]
|
||||
if label:
|
||||
labels += input_ids[prev_idx:]
|
||||
else:
|
||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||
prompt_len = len(input_ids)
|
||||
# remap the input_ids, attention_mask and labels
|
||||
rejected_input_ids = input_ids
|
||||
rejected_labels = labels
|
||||
# pass the chosen prompt/row to the Prompter to get the formatted prompt
|
||||
chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
|
||||
input_ids = []
|
||||
labels = []
|
||||
for _, (part, label) in enumerate(
|
||||
self.prompter.build_prompt(chosen_message_list)
|
||||
):
|
||||
if not part:
|
||||
continue
|
||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
||||
prev_idx = len(input_ids)
|
||||
input_ids += _input_ids[prev_idx:]
|
||||
if label:
|
||||
labels += input_ids[prev_idx:]
|
||||
else:
|
||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||
|
||||
return {
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_labels": rejected_labels,
|
||||
"rejected_attention_mask": [1] * len(rejected_labels),
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(labels),
|
||||
"prompt_attention_mask": [1] * prompt_len
|
||||
+ [0] * (len(labels) - prompt_len),
|
||||
}
|
||||
|
||||
|
||||
class ORPOPrompter(Prompter):
|
||||
"""Single Turn prompter for ORPO"""
|
||||
|
||||
def __init__(self, chat_template, tokenizer):
|
||||
self.chat_template = chat_template
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
message_list: MessageList,
|
||||
) -> Generator[Tuple[str, bool], None, None]:
|
||||
conversation = []
|
||||
for message in message_list.messages:
|
||||
conversation.append(message.model_dump())
|
||||
if message.role == "system":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), False
|
||||
if message.role == "user":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), False
|
||||
if message.role == "assistant":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), True
|
||||
@@ -85,7 +85,7 @@ def train(
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl:
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
|
||||
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
|
||||
templates = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
}
|
||||
|
||||
|
||||
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
|
||||
f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].conversation = "chatml"
|
||||
if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
|
||||
LOG.info(
|
||||
f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].chat_template = "chatml"
|
||||
|
||||
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
|
||||
@@ -124,6 +124,7 @@ class RLType(str, Enum):
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
@@ -431,6 +432,8 @@ class AxolotlInputConfig(
|
||||
dataloader_prefetch_factor: Optional[int] = None
|
||||
dataloader_drop_last: Optional[bool] = None
|
||||
|
||||
remove_unused_columns: Optional[bool] = None
|
||||
|
||||
push_dataset_to_hub: Optional[str] = None
|
||||
hf_use_auth_token: Optional[bool] = None
|
||||
|
||||
@@ -515,6 +518,8 @@ class AxolotlInputConfig(
|
||||
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
orpo_alpha: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
] = None
|
||||
|
||||
@@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
@@ -99,7 +99,7 @@ def _invert_ranges(
|
||||
|
||||
|
||||
def _merge_ranges(
|
||||
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
||||
given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Merges overlapping ranges and sorts the given ranges.
|
||||
@@ -194,7 +194,9 @@ class LayerNamePattern:
|
||||
"""
|
||||
return self.name_regex.match(name) is not None
|
||||
|
||||
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
||||
def _parse_pattern(
|
||||
self, pattern: str
|
||||
) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
|
||||
"""
|
||||
Extracts the range pattern from the given pattern.
|
||||
|
||||
|
||||
@@ -8,7 +8,8 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
@@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_strategies.orpo.chat_template import load
|
||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
)
|
||||
|
||||
|
||||
class OrpoTokenizationTest(unittest.TestCase):
|
||||
"""test case for the ORPO tokenization"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(
|
||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
||||
),
|
||||
]
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||
).select([0])
|
||||
|
||||
def test_orpo_integration(self):
|
||||
strat = load(
|
||||
self.tokenizer,
|
||||
DictDefault({"train_on_inputs": False}),
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(self.dataset[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
assert "labels" in res
|
||||
assert "prompt_attention_mask" in res
|
||||
|
||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["input_ids"]) == len(res["labels"])
|
||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
||||
|
||||
assert res["rejected_labels"][0] == -100
|
||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
||||
|
||||
assert res["labels"][0] == -100
|
||||
assert res["input_ids"][-1] == res["labels"][-1]
|
||||
|
||||
assert res["prompt_attention_mask"][0] == 1
|
||||
assert res["prompt_attention_mask"][-1] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user