diff --git a/docs/rlhf.md b/docs/rlhf.md index 9f5ba05fd..4f71184fc 100644 --- a/docs/rlhf.md +++ b/docs/rlhf.md @@ -34,6 +34,21 @@ datasets: rl: ipo ``` +#### ORPO + +Paper: https://arxiv.org/abs/2403.07691 + +```yaml +rl: orpo +orpo_alpha: 0.1 +remove_unused_columns: false + +chat_template: chatml +datasets: + - path: argilla/ultrafeedback-binarized-preferences-cleaned + type: orpo.chat_template +``` + #### Using local dataset files ```yaml datasets: diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 89ab023e5..a1a01d59d 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): LOG.warning(msg) parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH - if parsed_cfg.rl: + if parsed_cfg.rl and parsed_cfg.rl != "orpo": load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 05fd63ae8..7e004567a 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: register_chatml_template() - if cfg.rl: + if cfg.rl and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d11f0c653..42180f32b 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -11,10 +11,11 @@ import math import os import sys from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import List, Optional, Type, Union +from typing import Dict, List, Literal, Optional, Type, Union import torch import transformers @@ -200,6 +201,9 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "whether this is a qlora training"}, ) + orpo_alpha: Optional[float] = field( + default=None, + ) class AxolotlTrainer(Trainer): @@ -223,6 +227,9 @@ class AxolotlTrainer(Trainer): self.eval_data_collator = eval_data_collator super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") def create_optimizer(self): if self.args.loraplus_lr_ratio is None: @@ -465,8 +472,112 @@ class AxolotlTrainer(Trainer): # outputs = model(**inputs) # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) # return (loss, outputs) if return_outputs else loss + if self.args.orpo_alpha: + return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs) + def orpo_compute_custom_loss(self, logits, labels): + logits = logits.contiguous() + loss = 0.0 + + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( + dim=-1 + ) + + return loss + + def orpo_compute_logps( + self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits + ): + # Get the shape of chosen_attention_mask[:, :-1] + chosen_shape = chosen_attention_mask[:, :-1].shape + + # Calculate the padding size + pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) + + # Pad prompt_attention_mask with zeros to match the desired shape + prompt_attention_mask_padded = torch.nn.functional.pad( + prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 + ) + + # Perform the subtraction operation + mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded + + per_token_logps = torch.gather( + logits[:, :-1, :].log_softmax(-1), + dim=2, + index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), + ).squeeze(2) + return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to( + dtype=torch.float64 + ) / mask.sum(dim=1).to(dtype=torch.float64) + + def orpo_compute_loss(self, model, inputs, return_outputs=False): + outputs_neg = model( + **{ + "input_ids": inputs["rejected_input_ids"], + "attention_mask": inputs["rejected_attention_mask"], + "labels": inputs["rejected_labels"], + }, + output_hidden_states=True, + ) + outputs_pos = model( + **{ + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "labels": inputs["labels"], + }, + output_hidden_states=True, + ) + + # Calculate NLL loss + pos_loss = self.orpo_compute_custom_loss( + logits=outputs_pos.logits, labels=inputs["input_ids"] + ) + + # Calculate Log Probability + pos_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["input_ids"], + chosen_attention_mask=inputs["attention_mask"], + logits=outputs_pos.logits, + ) + neg_prob = self.orpo_compute_logps( + prompt_attention_mask=inputs["prompt_attention_mask"], + chosen_inputs=inputs["rejected_input_ids"], + chosen_attention_mask=inputs["rejected_attention_mask"], + logits=outputs_neg.logits, + ) + + # Calculate log odds + log_odds = (pos_prob - neg_prob) - ( + torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) + ) + sig_ratio = torch.nn.functional.sigmoid(log_odds) + ratio = torch.log(sig_ratio) + + # Calculate the Final Loss + loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( + dtype=torch.bfloat16 + ) + + metrics = {} + metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() + metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() + metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() + metrics["log_odds"] = torch.mean(log_odds).cpu().item() + self.store_metrics(metrics, train_eval="train") + + return (loss, outputs_pos) if return_outputs else loss + @wraps(Trainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ @@ -527,6 +638,28 @@ class AxolotlTrainer(Trainer): return res + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + def store_metrics( + self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + ) -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + class AxolotlMambaTrainer(AxolotlTrainer): """ @@ -903,6 +1036,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False: training_arguments_kwargs["dataloader_drop_last"] = True + if self.cfg.remove_unused_columns is not None: + training_arguments_kwargs[ + "remove_unused_columns" + ] = self.cfg.remove_unused_columns + if not self.cfg.test_datasets and self.cfg.val_set_size == 0: # no eval set, so don't eval training_arguments_kwargs["evaluation_strategy"] = "no" @@ -1070,6 +1208,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) + if self.cfg.rl == "orpo": + training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha + if self.cfg.neftune_noise_alpha is not None: training_arguments_kwargs[ "neftune_noise_alpha" diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py new file mode 100644 index 000000000..fce2aba14 --- /dev/null +++ b/src/axolotl/prompt_strategies/base.py @@ -0,0 +1,20 @@ +""" +module for base dataset transform strategies +""" + +import importlib +import logging + +LOG = logging.getLogger("axolotl") + + +def load(strategy, cfg, module_base=None, **kwargs): + try: + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + mod = importlib.import_module(f".{strategy}", module_base) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except Exception: # pylint: disable=broad-exception-caught + LOG.warning(f"unable to load strategy {strategy}") + return None diff --git a/src/axolotl/prompt_strategies/dpo/__init__.py b/src/axolotl/prompt_strategies/dpo/__init__.py index 8bd430f91..1a149f452 100644 --- a/src/axolotl/prompt_strategies/dpo/__init__.py +++ b/src/axolotl/prompt_strategies/dpo/__init__.py @@ -1,20 +1,8 @@ """ module for DPO style dataset transform strategies """ +from functools import partial -import importlib -import logging +from ..base import load as load_base -LOG = logging.getLogger("axolotl") - - -def load(strategy, cfg, **kwargs): - try: - load_fn = strategy.split(".")[-1] - strategy = ".".join(strategy.split(".")[:-1]) - mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo") - func = getattr(mod, load_fn) - return func(cfg, **kwargs) - except Exception: # pylint: disable=broad-exception-caught - LOG.warning(f"unable to load strategy {strategy}") - return None +load = partial(load_base, module="axolotl.prompt_strategies.dpo") diff --git a/src/axolotl/prompt_strategies/orpo/__init__.py b/src/axolotl/prompt_strategies/orpo/__init__.py new file mode 100644 index 000000000..3a961fcc9 --- /dev/null +++ b/src/axolotl/prompt_strategies/orpo/__init__.py @@ -0,0 +1,9 @@ +""" +module for ORPO style dataset transform strategies +""" + +from functools import partial + +from ..base import load as load_base + +load = partial(load_base, module="axolotl.prompt_strategies.orpo") diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py new file mode 100644 index 000000000..fb39bcf8f --- /dev/null +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -0,0 +1,187 @@ +"""chatml prompt tokenization strategy for ORPO""" +from typing import Any, Dict, Generator, List, Optional, Tuple + +from pydantic import BaseModel + +from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy +from axolotl.prompters import Prompter +from axolotl.utils.chat_templates import chat_templates + + +class Message(BaseModel): + """message/turn""" + + role: str + content: str + label: Optional[bool] = None + + +class MessageList(BaseModel): + """conversation""" + + messages: List[Message] + + +def load( + tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs +): # pylint: disable=possibly-unused-variable,unused-argument + """ + chatml transforms for datasets with system, input, chosen, rejected + """ + + chat_template = chat_templates("chatml") + if ds_cfg and "chat_template" in ds_cfg: + chat_template = ds_cfg["chat_template"] + try: + chat_template = chat_templates(chat_template) + except ValueError: + pass + + return ORPOTokenizingStrategy( + ORPOPrompter(chat_template, tokenizer), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + dataset_parser=ORPODatasetParsingStrategy(), + ) + + +class ORPODatasetParsingStrategy: + """Strategy to parse chosen rejected dataset into messagelist""" + + def get_chosen_conversation_thread(self, prompt) -> MessageList: + """Dataset structure mappings""" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message( + role="assistant", content=prompt["chosen"][1]["content"], label=True + ) + ) + return MessageList(messages=messages) + + def get_rejected_conversation_thread(self, prompt) -> MessageList: + """Dataset structure mappings""" + + messages: List[Message] = [] + if system := prompt.get("system", None): + messages.append(Message(role="system", content=system, label=False)) + messages.append(Message(role="user", content=prompt["prompt"], label=False)) + messages.append( + Message( + role="assistant", content=prompt["rejected"][1]["content"], label=True + ) + ) + return MessageList(messages=messages) + + +class ORPOTokenizingStrategy(PromptTokenizingStrategy): + """ + rejected_input_ids + input_ids + rejected_attention_mask + attention_mask + rejected_labels + labels + """ + + def __init__( + self, + *args, + dataset_parser=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.dataset_parser = dataset_parser + + def tokenize_prompt(self, prompt): + # pass the rejected prompt/row to the Prompter to get the formatted prompt + prompt_len = 0 + rejected_message_list = self.dataset_parser.get_rejected_conversation_thread( + prompt + ) + input_ids = [] + labels = [] + for _, (part, label) in enumerate( + self.prompter.build_prompt(rejected_message_list) + ): + if not part: + continue + _input_ids = self.tokenizer.encode(part, add_special_tokens=False) + prev_idx = len(input_ids) + input_ids += _input_ids[prev_idx:] + if label: + labels += input_ids[prev_idx:] + else: + labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) + prompt_len = len(input_ids) + # remap the input_ids, attention_mask and labels + rejected_input_ids = input_ids + rejected_labels = labels + # pass the chosen prompt/row to the Prompter to get the formatted prompt + chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt) + input_ids = [] + labels = [] + for _, (part, label) in enumerate( + self.prompter.build_prompt(chosen_message_list) + ): + if not part: + continue + _input_ids = self.tokenizer.encode(part, add_special_tokens=False) + prev_idx = len(input_ids) + input_ids += _input_ids[prev_idx:] + if label: + labels += input_ids[prev_idx:] + else: + labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx) + + return { + "rejected_input_ids": rejected_input_ids, + "rejected_labels": rejected_labels, + "rejected_attention_mask": [1] * len(rejected_labels), + "input_ids": input_ids, + "labels": labels, + "attention_mask": [1] * len(labels), + "prompt_attention_mask": [1] * prompt_len + + [0] * (len(labels) - prompt_len), + } + + +class ORPOPrompter(Prompter): + """Single Turn prompter for ORPO""" + + def __init__(self, chat_template, tokenizer): + self.chat_template = chat_template + self.tokenizer = tokenizer + + def build_prompt( + self, + message_list: MessageList, + ) -> Generator[Tuple[str, bool], None, None]: + conversation = [] + for message in message_list.messages: + conversation.append(message.model_dump()) + if message.role == "system": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), False + if message.role == "user": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=True, + chat_template=self.chat_template, + tokenize=False, + ), False + if message.role == "assistant": + yield self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=False, + chat_template=self.chat_template, + tokenize=False, + ), True diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d2fc75261..b6cd24672 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -85,7 +85,7 @@ def train( model.generation_config.do_sample = True model_ref = None - if cfg.rl: + if cfg.rl and cfg.rl != "orpo": if cfg.adapter and not cfg.rl_adapter_ref_model: # use built-in trl autounwrap LOG.debug("Passing model_ref: None to RL trainer") diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 1ec83536d..fd34b4ea9 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -21,7 +21,7 @@ def chat_templates(user_choice: str): templates = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. - "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", } diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 9151f288a..3e743bda9 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg): f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template" ) cfg.datasets[idx].conversation = "chatml" + if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template: + LOG.info( + f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template" + ) + cfg.datasets[idx].chat_template = "chatml" def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index dfe9a9be9..ef31c05c2 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -124,6 +124,7 @@ class RLType(str, Enum): dpo = "dpo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name + orpo = "orpo" # pylint: disable=invalid-name class ChatTemplate(str, Enum): @@ -431,6 +432,8 @@ class AxolotlInputConfig( dataloader_prefetch_factor: Optional[int] = None dataloader_drop_last: Optional[bool] = None + remove_unused_columns: Optional[bool] = None + push_dataset_to_hub: Optional[str] = None hf_use_auth_token: Optional[bool] = None @@ -515,6 +518,8 @@ class AxolotlInputConfig( neftune_noise_alpha: Optional[float] = None + orpo_alpha: Optional[float] = None + max_memory: Optional[ Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]] ] = None diff --git a/src/axolotl/utils/freeze.py b/src/axolotl/utils/freeze.py index 64b994f84..e3d0fd144 100644 --- a/src/axolotl/utils/freeze.py +++ b/src/axolotl/utils/freeze.py @@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name """ import logging import re -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Union from axolotl.utils.distributed import is_main_process @@ -99,7 +99,7 @@ def _invert_ranges( def _merge_ranges( - given_ranges: List[Tuple[int, int | None]], layer_size: int + given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int ) -> List[Tuple[int, int]]: """ Merges overlapping ranges and sorts the given ranges. @@ -194,7 +194,9 @@ class LayerNamePattern: """ return self.name_regex.match(name) is not None - def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]: + def _parse_pattern( + self, pattern: str + ) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]: """ Extracts the range pattern from the given pattern. diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 077b63b37..4e659006f 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -8,7 +8,8 @@ from pathlib import Path from typing import Optional import pytest -from transformers import AutoTokenizer, LlamaTokenizer +from datasets import load_dataset +from transformers import AddedToken, AutoTokenizer, LlamaTokenizer from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter from axolotl.prompt_strategies.alpaca_w_system import ( @@ -19,12 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import ( Llama2ChatPrompter, LLama2ChatTokenizingStrategy, ) +from axolotl.prompt_strategies.orpo.chat_template import load from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, ) from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2 +from axolotl.utils.dict import DictDefault LOG = logging.getLogger("axolotl") @@ -446,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why ) +class OrpoTokenizationTest(unittest.TestCase): + """test case for the ORPO tokenization""" + + def setUp(self) -> None: + # pylint: disable=duplicate-code + tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + tokenizer.add_special_tokens( + { + "eos_token": AddedToken( + "<|im_end|>", rstrip=False, lstrip=False, normalized=False + ) + } + ) + tokenizer.add_tokens( + [ + AddedToken( + "<|im_start|>", rstrip=False, lstrip=False, normalized=False + ), + ] + ) + self.tokenizer = tokenizer + self.dataset = load_dataset( + "argilla/ultrafeedback-binarized-preferences-cleaned", split="train" + ).select([0]) + + def test_orpo_integration(self): + strat = load( + self.tokenizer, + DictDefault({"train_on_inputs": False}), + DictDefault({"chat_template": "chatml"}), + ) + res = strat.tokenize_prompt(self.dataset[0]) + assert "rejected_input_ids" in res + assert "rejected_labels" in res + assert "input_ids" in res + assert "labels" in res + assert "prompt_attention_mask" in res + + assert len(res["rejected_input_ids"]) == len(res["rejected_labels"]) + assert len(res["input_ids"]) == len(res["labels"]) + assert len(res["input_ids"]) == len(res["prompt_attention_mask"]) + + assert res["rejected_labels"][0] == -100 + assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1] + + assert res["labels"][0] == -100 + assert res["input_ids"][-1] == res["labels"][-1] + + assert res["prompt_attention_mask"][0] == 1 + assert res["prompt_attention_mask"][-1] == 0 + + if __name__ == "__main__": unittest.main()