test
This commit is contained in:
@@ -27,7 +27,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -42,6 +42,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {},
|
||||
"rl": "dpo",
|
||||
"rpo_alpha": 0.5,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
@@ -71,7 +72,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -86,7 +87,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {},
|
||||
"rl": "dpo",
|
||||
"rpo_alpha": 0.5,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
|
||||
Reference in New Issue
Block a user