handle tp load
This commit is contained in:
@@ -650,6 +650,14 @@ class ModelLoader:
|
|||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
|
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
|
self.model_kwargs["tp_plan"] = "auto"
|
||||||
|
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||||
|
if "device_map" in self.model_kwargs:
|
||||||
|
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|||||||
Reference in New Issue
Block a user