diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index a1cad3531..4e440c8a6 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -396,13 +396,13 @@ class ModelLoader: """Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator""" dp_replicate_size = get_world_size() pc_kwargs = {} - if self.cfg.dp_shard_size > 1: + if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1: pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size - if self.cfg.tensor_parallel_size > 1: + if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1: pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size - if self.cfg.context_parallel_size > 1: + if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1: pc_kwargs["cp_size"] = self.cfg.context_parallel_size dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size if dp_replicate_size > 1: