diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8f477ff16..d44a4b4b6 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -826,6 +826,11 @@ class ModelLoader: _ = _configure_zero3_memory_efficient_loading() + if self.cfg.tensor_parallel == "auto": + rank = int(os.environ["RANK"]) + device = torch.device(f"cuda:{rank}") + torch.distributed.init_process_group("nccl", device_id=device) + if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config self.model = self.AutoModelLoader.from_pretrained(