From 85381b6b155ba796899872b23c98443d76986e70 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Wed, 11 Dec 2024 11:35:16 -0500 Subject: [PATCH] initialise process group for tp --- src/axolotl/utils/models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8f477ff16..d44a4b4b6 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -826,6 +826,11 @@ class ModelLoader: _ = _configure_zero3_memory_efficient_loading() + if self.cfg.tensor_parallel == "auto": + rank = int(os.environ["RANK"]) + device = torch.device(f"cuda:{rank}") + torch.distributed.init_process_group("nccl", device_id=device) + if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config self.model = self.AutoModelLoader.from_pretrained(