diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fbe23430d..ca695846f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -828,6 +828,7 @@ class ModelLoader: if self.cfg.tensor_parallel == "auto": rank = int(os.environ.get("LOCAL_RANK", 0)) + os.environ["RANK"] = str(rank) device = torch.device(f"cuda:{rank}") torch.distributed.init_process_group("nccl", device_id=device)