stuff
This commit is contained in:
@@ -1319,6 +1319,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
# self.model =
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_ref(self):
|
def model_ref(self):
|
||||||
return self._model_ref
|
return self._model_ref
|
||||||
|
|||||||
@@ -621,7 +621,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.model_kwargs["device_map"] = device_map
|
self.model_kwargs["device_map"] = device_map
|
||||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||||
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
|
|
||||||
|
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "mps" in str(cur_device):
|
if "mps" in str(cur_device):
|
||||||
@@ -826,16 +825,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
_ = _configure_zero3_memory_efficient_loading()
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.tensor_parallel == "auto":
|
|
||||||
from accelerate import Accelerator
|
|
||||||
|
|
||||||
Accelerator()
|
|
||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
||||||
os.environ["RANK"] = str(rank)
|
|
||||||
os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
|
||||||
torch.distributed.init_process_group("nccl", device_id=device)
|
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
@@ -1198,9 +1187,15 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
self.post_loading_set_env()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return self.model, lora_config
|
return self.model, lora_config
|
||||||
|
|
||||||
|
def post_loading_set_env(self):
|
||||||
|
if self.cfg.tensor_parallel == "auto" and self.model.supports_tp_plan:
|
||||||
|
os.environ["ACCELERATE_USE_TP"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
|
|||||||
Reference in New Issue
Block a user