stuff
This commit is contained in:
@@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["axolotl"])
|
||||
|
||||
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||
# self.model =
|
||||
|
||||
@property
|
||||
def model_ref(self):
|
||||
return self._model_ref
|
||||
|
||||
@@ -621,7 +621,6 @@ class ModelLoader:
|
||||
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
|
||||
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
@@ -826,16 +825,6 @@ class ModelLoader:
|
||||
|
||||
_ = _configure_zero3_memory_efficient_loading()
|
||||
|
||||
if self.cfg.tensor_parallel == "auto":
|
||||
from accelerate import Accelerator
|
||||
|
||||
Accelerator()
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.distributed.init_process_group("nccl", device_id=device)
|
||||
|
||||
if self.cfg.is_multimodal:
|
||||
self.model_config.text_config = self.text_model_config
|
||||
self.model = self.AutoModelLoader.from_pretrained(
|
||||
@@ -1198,9 +1187,15 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.post_loading_set_env()
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return self.model, lora_config
|
||||
|
||||
def post_loading_set_env(self):
|
||||
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg: DictDefault,
|
||||
|
||||
Reference in New Issue
Block a user