diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1041892f5..a1cad3531 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -650,6 +650,14 @@ class ModelLoader: def _build_model(self) -> bool: """Load model, with load strategy depending on config.""" skip_move_to_device = False + + if self.cfg.tensor_parallel_size > 1: + self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size + self.model_kwargs["tp_plan"] = "auto" + self.model_kwargs["device_mesh"] = PartialState().device_mesh + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] # not compatible with `tp_plan` + if self.is_fsdp_enabled: if self.cfg.fsdp_config.cpu_ram_efficient_loading: skip_move_to_device = True