DPO transformers v0.29 fixes (#3560) [skip ci]
* Deprecate dpo_norm_loss
* Rename chosen/rejected_input_ids to chosen/rejected_ids to match TRL https://github.com/huggingface/trl/pull/5179
* Remove deprecated rpo_alpha
* Remove dead code tokenize_row
* Add _tokenize override to prevent double BOS token on Llama DPO
* Fix DPO loss type: it is now a list, not a string
* Linting fix
* PR fixes
* Update _tokenize override for DPO for multimodal
This commit is contained in:
@@ -67,55 +67,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"rpo_alpha": 0.5,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
"type": "chatml.ultra",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_use_weighting(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
|
||||
@@ -223,18 +223,18 @@ class OrpoTokenizationTest:
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(ds[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
assert "labels" in res
|
||||
assert "prompt_attention_mask" in res
|
||||
|
||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["rejected_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["input_ids"]) == len(res["labels"])
|
||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
||||
|
||||
assert res["rejected_labels"][0] == -100
|
||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
||||
assert res["rejected_ids"][-1] == res["rejected_labels"][-1]
|
||||
|
||||
assert res["labels"][0] == -100
|
||||
assert res["input_ids"][-1] == res["labels"][-1]
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset, remove_double_bos_token
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -541,5 +541,33 @@ class TestHandleLongSeqInDataset(unittest.TestCase):
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
|
||||
|
||||
class TestRemoveDoubleBOSToken(unittest.TestCase):
|
||||
def test_no_remove_bos_token(self):
|
||||
input_ids = [0, 1, 2]
|
||||
labels = [1, 2, 3]
|
||||
|
||||
example = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
example = remove_double_bos_token(example, 0)
|
||||
assert example["input_ids"] == input_ids
|
||||
assert example["labels"] == labels
|
||||
|
||||
def test_remove_bos_token(self):
|
||||
input_ids = [0, 0, 1]
|
||||
labels = [0, 1, 2]
|
||||
|
||||
example = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
example = remove_double_bos_token(example, 0)
|
||||
assert example["input_ids"] == [0, 1]
|
||||
assert example["labels"] == [1, 2]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user