test accelerator

This commit is contained in:
bursteratom
2024-12-12 12:29:35 -05:00
parent 2014f58181
commit c760d2b815

View File

@@ -827,6 +827,9 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.tensor_parallel == "auto":
from accelerate import Accelerator
Accelerator()
rank = int(os.environ.get("LOCAL_RANK", 0))
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")