Handle None values when checking parallelism sizes (dp_shard_size, tensor_parallel_size, context_parallel_size)

This commit is contained in:
Wing Lian
2025-07-22 21:21:45 -04:00
parent 9a2da4d9f0
commit cca207eec4

View File

@@ -396,13 +396,13 @@ class ModelLoader:
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
dp_replicate_size = get_world_size()
pc_kwargs = {}
if self.cfg.dp_shard_size > 1:
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
if self.cfg.tensor_parallel_size > 1:
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
if self.cfg.context_parallel_size > 1:
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
if dp_replicate_size > 1: