test
This commit is contained in:
@@ -27,7 +27,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_dpo_nll_lora(self, temp_dir):
|
def test_dpo_lora(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -42,7 +42,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"special_tokens": {},
|
"special_tokens": {},
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"rpo_alpha": 0.5,
|
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||||
@@ -72,7 +71,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_dpo_lora(self, temp_dir):
|
def test_dpo_nll_lora(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -87,6 +86,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"special_tokens": {},
|
"special_tokens": {},
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
|
"rpo_alpha": 0.5,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||||
|
|||||||
Reference in New Issue
Block a user