This commit is contained in:
bursteratom
2024-12-13 15:44:51 -05:00
parent c760d2b815
commit 60c98a4353
2 changed files with 10 additions and 11 deletions

View File

@@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
if hasattr(model, "add_model_tags"):
model.add_model_tags(["axolotl"])
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
os.environ["ACCELERATE_USE_TP"] = "true"
# self.model =
@property
def model_ref(self):
    """Read-only accessor for the reference model held by this builder.

    Returns:
        The object stored in ``self._model_ref`` (set elsewhere in the
        builder; may be ``None`` if no reference model was configured —
        TODO confirm against the initializer, which is outside this view).
    """
    return self._model_ref

View File

@@ -621,7 +621,6 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
cur_device = get_device_type()
if "mps" in str(cur_device):
@@ -826,16 +825,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.tensor_parallel == "auto":
from accelerate import Accelerator
Accelerator()
rank = int(os.environ.get("LOCAL_RANK", 0))
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
@@ -1198,9 +1187,15 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
self.post_loading_set_env()
# TODO resume_from_checkpoint handling
return self.model, lora_config
def post_loading_set_env(self):
    """Adjust process environment after the model has been loaded.

    When the config requests automatic tensor parallelism
    (``cfg.tensor_parallel == "auto"``) and the loaded model advertises a
    TP plan (``model.supports_tp_plan``), export ``ACCELERATE_USE_TP=true``
    so downstream Accelerate machinery picks up tensor parallelism.
    Otherwise the environment is left untouched.
    """
    # Guard clauses: bail out early unless both conditions hold.
    if self.cfg.tensor_parallel != "auto":
        return
    if not self.model.supports_tp_plan:
        return
    os.environ["ACCELERATE_USE_TP"] = "true"
def load_model(
cfg: DictDefault,