diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a3f90bca5..a32a6886d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -827,6 +827,9 @@ class ModelLoader: _ = _configure_zero3_memory_efficient_loading() if self.cfg.tensor_parallel == "auto": + from accelerate import Accelerator + + Accelerator() rank = int(os.environ.get("LOCAL_RANK", 0)) os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")