From 60c98a4353f69db5e4c237f4b9fe61b96d6c27b7 Mon Sep 17 00:00:00 2001 From: bursteratom Date: Fri, 13 Dec 2024 15:44:51 -0500 Subject: [PATCH] stuff --- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/utils/models.py | 17 ++++++----------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 73d9e0e65..4e34ce706 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC): if hasattr(model, "add_model_tags"): model.add_model_tags(["axolotl"]) + if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan: + os.environ["ACCELERATE_USE_TP"] = "true" + # self.model = + @property def model_ref(self): return self._model_ref diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index a32a6886d..3d79f0116 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -621,7 +621,6 @@ class ModelLoader: self.model_kwargs["device_map"] = device_map self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype - self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel cur_device = get_device_type() if "mps" in str(cur_device): @@ -826,16 +825,6 @@ class ModelLoader: _ = _configure_zero3_memory_efficient_loading() - if self.cfg.tensor_parallel == "auto": - from accelerate import Accelerator - - Accelerator() - rank = int(os.environ.get("LOCAL_RANK", 0)) - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1") - device = torch.device(f"cuda:{rank}") - torch.distributed.init_process_group("nccl", device_id=device) - if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config self.model = self.AutoModelLoader.from_pretrained( @@ -1198,9 +1187,15 @@ class ModelLoader: gc.collect() torch.cuda.empty_cache() + self.post_loading_set_env() + # TODO resume_from_checkpoint handling return self.model, lora_config + def post_loading_set_env(self): + if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan: + os.environ["ACCELERATE_USE_TP"] = "true" + def load_model( cfg: DictDefault,