This commit is contained in:
sunny
2024-11-05 12:38:33 -05:00
parent d9b65f69fb
commit 41d10278bf

View File

@@ -27,7 +27,7 @@ class TestDPOLlamaLora(unittest.TestCase):
""" """
@with_temp_dir @with_temp_dir
def test_dpo_nll_lora(self, temp_dir): def test_dpo_lora(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -42,7 +42,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_target_linear": True, "lora_target_linear": True,
"special_tokens": {}, "special_tokens": {},
"rl": "dpo", "rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [ "datasets": [
{ {
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
@@ -72,7 +71,7 @@ class TestDPOLlamaLora(unittest.TestCase):
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir @with_temp_dir
def test_dpo_lora(self, temp_dir): def test_dpo_nll_lora(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -87,6 +86,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"lora_target_linear": True, "lora_target_linear": True,
"special_tokens": {}, "special_tokens": {},
"rl": "dpo", "rl": "dpo",
"rpo_alpha": 0.5,
"datasets": [ "datasets": [
{ {
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",