From c760d2b815dd96e81de1084944a2411c3f91bd13 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Thu, 12 Dec 2024 12:29:35 -0500 Subject: [PATCH] test accelerator --- src/axolotl/utils/models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a3f90bca5..a32a6886d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -827,6 +827,9 @@ class ModelLoader: _ = _configure_zero3_memory_efficient_loading() if self.cfg.tensor_parallel == "auto": + from accelerate import Accelerator + + Accelerator() rank = int(os.environ.get("LOCAL_RANK", 0)) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")