Distributed/ND-Parallel (#2977)

This commit is contained in:
salman
2025-07-31 20:25:02 +01:00
committed by GitHub
parent 7b68dfafd7
commit 294c7fe7a6
49 changed files with 712 additions and 835 deletions

View File

@@ -171,3 +171,44 @@ class TestModelsUtils:
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)
@pytest.mark.parametrize(
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
[
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
],
)
def test_get_parallel_config_kwargs(
self,
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
expected,
):
res = (
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
)
if expected[0] > 1:
assert res["tp_size"] == expected[0]
if expected[1] > 1:
assert res["cp_size"] == expected[1]
if expected[2] > 1:
assert res["dp_shard_size"] == expected[2]
if expected[3] > 1:
assert res["dp_replicate_size"] == expected[3]