initialise process group for tp

This commit is contained in:
bursteratom
2024-12-11 11:35:16 -05:00
parent acde081321
commit 85381b6b15

View File

@@ -826,6 +826,11 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.tensor_parallel == "auto":
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(