enable tensor parallel
This commit is contained in:
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
|||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None
|
model_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
|
|||||||
@@ -621,6 +621,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.model_kwargs["device_map"] = device_map
|
self.model_kwargs["device_map"] = device_map
|
||||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||||
|
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
|
||||||
|
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "mps" in str(cur_device):
|
if "mps" in str(cur_device):
|
||||||
|
|||||||
Reference in New Issue
Block a user