limit num_proc when saving datasets to disk (#2948) [skip ci]
* limit num_proc when saving datasets to disk * enforce at least 1 in case it rounds down to 0, and sane divisor is at least 8 rows per worker to save * update fixtures with dataset processes since that should never be NoneType * improve reusability for tests
This commit is contained in:
@@ -82,6 +82,7 @@ def fixture_base_cfg():
|
||||
"ddp_timeout": 1800,
|
||||
"ddp_bucket_cap_mb": 25,
|
||||
"ddp_broadcast_buffers": False,
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -440,6 +441,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
|
||||
cfg["dataset_processes"] = 4
|
||||
|
||||
if cfg_string == "grpo_cfg":
|
||||
rewards_dir = tmp_path / "rewards_test"
|
||||
|
||||
@@ -141,6 +141,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -179,6 +180,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -217,6 +219,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -249,6 +252,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -281,6 +285,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -365,6 +370,7 @@ class TestDatasetPreparation:
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -466,6 +472,7 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -210,6 +210,7 @@ class TestDeduplicateRLDataset:
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
yield fixture
|
||||
|
||||
@@ -99,6 +99,7 @@ class TestPacking(unittest.TestCase):
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
|
||||
Reference in New Issue
Block a user