handle none checks
This commit is contained in:
@@ -396,13 +396,13 @@ class ModelLoader:
|
|||||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
dp_replicate_size = get_world_size()
|
dp_replicate_size = get_world_size()
|
||||||
pc_kwargs = {}
|
pc_kwargs = {}
|
||||||
if self.cfg.dp_shard_size > 1:
|
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
|
||||||
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
|
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
|
||||||
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
|
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
|
||||||
if self.cfg.tensor_parallel_size > 1:
|
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
|
||||||
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
|
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
|
||||||
if self.cfg.context_parallel_size > 1:
|
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||||
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
|
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
|
||||||
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
|
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
|
||||||
if dp_replicate_size > 1:
|
if dp_replicate_size > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user