Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -171,3 +171,44 @@ class TestModelsUtils:
|
||||
message_property_mappings={"content": "different_content"},
|
||||
)
|
||||
assert "Conflicting message content fields" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
|
||||
[
|
||||
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
|
||||
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
|
||||
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
|
||||
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
|
||||
],
|
||||
)
|
||||
def test_get_parallel_config_kwargs(
|
||||
self,
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = (
|
||||
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
)
|
||||
)
|
||||
|
||||
if expected[0] > 1:
|
||||
assert res["tp_size"] == expected[0]
|
||||
if expected[1] > 1:
|
||||
assert res["cp_size"] == expected[1]
|
||||
if expected[2] > 1:
|
||||
assert res["dp_shard_size"] == expected[2]
|
||||
if expected[3] > 1:
|
||||
assert res["dp_replicate_size"] == expected[3]
|
||||
|
||||
Reference in New Issue
Block a user