DPO transformers v0.29 fixes (#3560) [skip ci]
* Deprecate dpo_norm_loss * Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179 * Remove deprecated rpo_alpha * Remove dead code tokenize_row * Add _tokenize override to prevent double bos token on Llama DPO * Fix DPO loss type, which is now a list rather than a string * Linting fix * PR fixes * Update _tokenize override for DPO for multimodal
This commit is contained in:
@@ -127,9 +127,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||||
|
|
||||||
if self.cfg.rpo_alpha is not None:
|
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
|
||||||
|
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
|
|
||||||
|
|||||||
@@ -405,15 +405,13 @@ class AxolotlTrainer(
|
|||||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||||
concatenated_batch = {}
|
concatenated_batch = {}
|
||||||
|
|
||||||
max_length = max(
|
max_length = max(inputs["input_ids"].shape[1], inputs["rejected_ids"].shape[1])
|
||||||
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
|
|
||||||
)
|
|
||||||
# Concatenate positive and negative inputs
|
# Concatenate positive and negative inputs
|
||||||
concatenated_batch["input_ids"] = pad_to_length(
|
concatenated_batch["input_ids"] = pad_to_length(
|
||||||
inputs["input_ids"], max_length, pad_token
|
inputs["input_ids"], max_length, pad_token
|
||||||
)
|
)
|
||||||
concatenated_batch["rejected_input_ids"] = pad_to_length(
|
concatenated_batch["rejected_ids"] = pad_to_length(
|
||||||
inputs["rejected_input_ids"], max_length, pad_token
|
inputs["rejected_ids"], max_length, pad_token
|
||||||
)
|
)
|
||||||
concatenated_batch["labels"] = pad_to_length(
|
concatenated_batch["labels"] = pad_to_length(
|
||||||
inputs["labels"], max_length, label_pad_token
|
inputs["labels"], max_length, label_pad_token
|
||||||
@@ -432,7 +430,7 @@ class AxolotlTrainer(
|
|||||||
).to(device=device)
|
).to(device=device)
|
||||||
|
|
||||||
input_ids = torch.cat(
|
input_ids = torch.cat(
|
||||||
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
|
[concatenated_batch["input_ids"], concatenated_batch["rejected_ids"]],
|
||||||
dim=0,
|
dim=0,
|
||||||
).to(device=device)
|
).to(device=device)
|
||||||
attention_mask = torch.cat(
|
attention_mask = torch.cat(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class DPOStrategy:
|
|||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
training_args_kwargs = {}
|
training_args_kwargs = {}
|
||||||
if cfg.rl is RLType.IPO:
|
if cfg.rl is RLType.IPO:
|
||||||
training_args_kwargs["loss_type"] = "ipo"
|
training_args_kwargs["loss_type"] = ["ipo"]
|
||||||
# Label smoothing is not compatible with IPO
|
# Label smoothing is not compatible with IPO
|
||||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
@@ -30,8 +30,6 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||||
if cfg.dpo_norm_loss is not None:
|
|
||||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
|
||||||
if cfg.dpo_use_liger_kernel is not None:
|
if cfg.dpo_use_liger_kernel is not None:
|
||||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|||||||
@@ -2,8 +2,7 @@
|
|||||||
Axolotl specific DPO args
|
Axolotl specific DPO args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from trl import DPOConfig
|
from trl import DPOConfig
|
||||||
|
|
||||||
@@ -15,6 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
|||||||
"""
|
"""
|
||||||
DPO config for DPO training
|
DPO config for DPO training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dpo_norm_loss: bool | None = False
|
|
||||||
rpo_alpha: Optional[float] = field(default=None)
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
@@ -18,6 +19,7 @@ from axolotl.core.trainers.utils import (
|
|||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.data.utils import remove_double_bos_token
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(
|
class AxolotlDPOTrainer(
|
||||||
@@ -53,36 +55,31 @@ class AxolotlDPOTrainer(
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
def _tokenize(
|
||||||
def tokenize_row(
|
self,
|
||||||
features,
|
processing_class: PreTrainedTokenizerBase | ProcessorMixin,
|
||||||
processing_class,
|
input: str | list,
|
||||||
max_prompt_length: int | None = None,
|
**kwargs,
|
||||||
max_completion_length: int | None = None,
|
) -> dict[str, list]:
|
||||||
add_special_tokens: bool = True,
|
"""
|
||||||
is_chat: bool = False,
|
Override TRL's tokenization in DPO trainer to fix double bos_token bug (eg. llama).
|
||||||
) -> Dict:
|
"""
|
||||||
res = DPOTrainer.tokenize_row(
|
result = super()._tokenize(
|
||||||
features,
|
processing_class=processing_class, input=input, **kwargs
|
||||||
processing_class,
|
|
||||||
max_prompt_length=max_prompt_length,
|
|
||||||
max_completion_length=max_completion_length,
|
|
||||||
add_special_tokens=add_special_tokens,
|
|
||||||
is_chat=is_chat,
|
|
||||||
)
|
)
|
||||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
|
||||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
|
||||||
for key in res.keys():
|
|
||||||
res[key] = res[key][1:]
|
|
||||||
|
|
||||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
# Handle multimodal models
|
||||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
tokenizer = (
|
||||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
getattr(processing_class, "tokenizer", None)
|
||||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
if isinstance(processing_class, ProcessorMixin)
|
||||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
else processing_class
|
||||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
)
|
||||||
|
|
||||||
return res
|
bos_token_id = getattr(tokenizer, "bos_token_id", None) if tokenizer else None
|
||||||
|
if bos_token_id is not None:
|
||||||
|
result = remove_double_bos_token(result, bos_token_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self,
|
||||||
@@ -94,20 +91,3 @@ class AxolotlDPOTrainer(
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def concatenated_forward(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
batch: dict[str, Union[list, torch.LongTensor]],
|
|
||||||
is_ref_model: bool = False,
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
if self.args.dpo_norm_loss:
|
|
||||||
# fmt: off
|
|
||||||
loss_type: list[str] = self.loss_type # type: ignore[has-type]
|
|
||||||
# fmt: on
|
|
||||||
# concatenated_forward handles avg token logprob for ipo case already
|
|
||||||
self.loss_type = ["ipo"]
|
|
||||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
|
||||||
self.loss_type = loss_type
|
|
||||||
return res
|
|
||||||
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
|
||||||
|
|||||||
@@ -71,10 +71,10 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"chosen_input_ids": chosen_tokenized["input_ids"],
|
"chosen_ids": chosen_tokenized["input_ids"],
|
||||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||||
"labels_chosen": 1.0,
|
"labels_chosen": 1.0,
|
||||||
"rejected_input_ids": rejected_tokenized["input_ids"],
|
"rejected_ids": rejected_tokenized["input_ids"],
|
||||||
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
||||||
"labels_rejected": 0.0,
|
"labels_rejected": 0.0,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class ORPODatasetParsingStrategy:
|
|||||||
|
|
||||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
rejected_input_ids
|
rejected_ids
|
||||||
input_ids
|
input_ids
|
||||||
rejected_attention_mask
|
rejected_attention_mask
|
||||||
attention_mask
|
attention_mask
|
||||||
@@ -169,7 +169,7 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||||
prompt_len = len(input_ids)
|
prompt_len = len(input_ids)
|
||||||
# remap the input_ids, attention_mask and labels
|
# remap the input_ids, attention_mask and labels
|
||||||
rejected_input_ids = input_ids
|
rejected_ids = input_ids
|
||||||
rejected_labels = labels
|
rejected_labels = labels
|
||||||
# pass the chosen prompt/row to the Prompter to get the formatted prompt
|
# pass the chosen prompt/row to the Prompter to get the formatted prompt
|
||||||
chosen_message_list: MessageList = (
|
chosen_message_list: MessageList = (
|
||||||
@@ -191,7 +191,7 @@ class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"rejected_input_ids": rejected_input_ids,
|
"rejected_ids": rejected_ids,
|
||||||
"rejected_labels": rejected_labels,
|
"rejected_labels": rejected_labels,
|
||||||
"rejected_attention_mask": [1] * len(rejected_labels),
|
"rejected_attention_mask": [1] * len(rejected_labels),
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
|||||||
@@ -349,3 +349,14 @@ def handle_long_seq_in_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def remove_double_bos_token(example: dict[str, list], bos_token_id: int):
|
||||||
|
"""Remove double bos tokens that may occur when retokenizing preprocessed data
|
||||||
|
for tokenizers and chat templates that have a bos_token - eg. DPO + Llama.
|
||||||
|
"""
|
||||||
|
input_ids = example["input_ids"]
|
||||||
|
if len(input_ids) >= 2 and input_ids[0] == input_ids[1] == bos_token_id:
|
||||||
|
for key in example:
|
||||||
|
example[key] = example[key][1:]
|
||||||
|
return example
|
||||||
|
|||||||
@@ -294,7 +294,6 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
|
||||||
|
|
||||||
dpo_use_liger_kernel: bool | None = Field(
|
dpo_use_liger_kernel: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -1111,12 +1110,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping."
|
"description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
rpo_alpha: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Weighting of NLL term in loss from RPO paper"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
simpo_gamma: float | None = Field(
|
simpo_gamma: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Target reward margin for the SimPO loss"},
|
json_schema_extra={"description": "Target reward margin for the SimPO loss"},
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ class DeprecatedParameters(BaseModel):
|
|||||||
eval_max_new_tokens: int | None = None
|
eval_max_new_tokens: int | None = None
|
||||||
dpo_use_logits_to_keep: bool | None = None
|
dpo_use_logits_to_keep: bool | None = None
|
||||||
dpo_generate_during_eval: bool | None = None
|
dpo_generate_during_eval: bool | None = None
|
||||||
|
dpo_norm_loss: bool | None = None
|
||||||
|
rpo_alpha: float | None = None
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
@field_validator("max_packed_sequence_len")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -100,6 +102,26 @@ class DeprecatedParameters(BaseModel):
|
|||||||
)
|
)
|
||||||
return dpo_generate_during_eval
|
return dpo_generate_during_eval
|
||||||
|
|
||||||
|
@field_validator("dpo_norm_loss")
|
||||||
|
@classmethod
|
||||||
|
def validate_dpo_norm_loss(cls, dpo_norm_loss):
|
||||||
|
if dpo_norm_loss is not None:
|
||||||
|
raise DeprecationWarning(
|
||||||
|
"`dpo_norm_loss` is no longer supported, "
|
||||||
|
"due to breaking changes in TRL >= 0.29.0"
|
||||||
|
)
|
||||||
|
return dpo_norm_loss
|
||||||
|
|
||||||
|
@field_validator("rpo_alpha")
|
||||||
|
@classmethod
|
||||||
|
def validate_rpo_alpha(cls, rpo_alpha):
|
||||||
|
if rpo_alpha is not None:
|
||||||
|
raise DeprecationWarning(
|
||||||
|
"`rpo_alpha` has been deprecated in TRL >= 0.29.0, "
|
||||||
|
"and now requires passing multiple loss types, which is not yet supported by Axolotl."
|
||||||
|
) # TODO: change this warning once multiple dpo loss types are supported.
|
||||||
|
return rpo_alpha
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
class RemappedParameters(BaseModel):
|
||||||
"""Parameters that have been remapped to other names"""
|
"""Parameters that have been remapped to other names"""
|
||||||
|
|||||||
@@ -67,55 +67,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_nll_lora(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 64,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.1,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"rl": "dpo",
|
|
||||||
"rpo_alpha": 0.5,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
|
||||||
"type": "chatml.ultra",
|
|
||||||
"split": "train",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "paged_adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 20,
|
|
||||||
"save_steps": 10,
|
|
||||||
"warmup_steps": 5,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_dpo_use_weighting(self, temp_dir):
|
def test_dpo_use_weighting(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -223,18 +223,18 @@ class OrpoTokenizationTest:
|
|||||||
DictDefault({"chat_template": "chatml"}),
|
DictDefault({"chat_template": "chatml"}),
|
||||||
)
|
)
|
||||||
res = strat.tokenize_prompt(ds[0])
|
res = strat.tokenize_prompt(ds[0])
|
||||||
assert "rejected_input_ids" in res
|
assert "rejected_ids" in res
|
||||||
assert "rejected_labels" in res
|
assert "rejected_labels" in res
|
||||||
assert "input_ids" in res
|
assert "input_ids" in res
|
||||||
assert "labels" in res
|
assert "labels" in res
|
||||||
assert "prompt_attention_mask" in res
|
assert "prompt_attention_mask" in res
|
||||||
|
|
||||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
assert len(res["rejected_ids"]) == len(res["rejected_labels"])
|
||||||
assert len(res["input_ids"]) == len(res["labels"])
|
assert len(res["input_ids"]) == len(res["labels"])
|
||||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
||||||
|
|
||||||
assert res["rejected_labels"][0] == -100
|
assert res["rejected_labels"][0] == -100
|
||||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
assert res["rejected_ids"][-1] == res["rejected_labels"][-1]
|
||||||
|
|
||||||
assert res["labels"][0] == -100
|
assert res["labels"][0] == -100
|
||||||
assert res["input_ids"][-1] == res["labels"][-1]
|
assert res["input_ids"][-1] == res["labels"][-1]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -541,5 +541,33 @@ class TestHandleLongSeqInDataset(unittest.TestCase):
|
|||||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRemoveDoubleBOSToken(unittest.TestCase):
|
||||||
|
def test_no_remove_bos_token(self):
|
||||||
|
input_ids = [0, 1, 2]
|
||||||
|
labels = [1, 2, 3]
|
||||||
|
|
||||||
|
example = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
example = remove_double_bos_token(example, 0)
|
||||||
|
assert example["input_ids"] == input_ids
|
||||||
|
assert example["labels"] == labels
|
||||||
|
|
||||||
|
def test_remove_bos_token(self):
|
||||||
|
input_ids = [0, 0, 1]
|
||||||
|
labels = [0, 1, 2]
|
||||||
|
|
||||||
|
example = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
example = remove_double_bos_token(example, 0)
|
||||||
|
assert example["input_ids"] == [0, 1]
|
||||||
|
assert example["labels"] == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user