set os environ RANK

This commit is contained in:
bursteratom
2024-12-11 11:40:20 -05:00
parent b17b1aada7
commit b5f9dd44f2

View File

@@ -828,6 +828,7 @@ class ModelLoader:
if self.cfg.tensor_parallel == "auto":
rank = int(os.environ.get("LOCAL_RANK", 0))
os.environ["RANK"] = str(rank)
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)