handle none checks
This commit is contained in:
@@ -396,13 +396,13 @@ class ModelLoader:
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
dp_replicate_size = get_world_size()
|
||||
pc_kwargs = {}
|
||||
if self.cfg.dp_shard_size > 1:
|
||||
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
|
||||
if self.cfg.context_parallel_size > 1:
|
||||
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
|
||||
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
|
||||
if dp_replicate_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user