enable tensor parallel
This commit is contained in:
@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
|
||||
@@ -621,6 +621,7 @@ class ModelLoader:
|
||||
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
|
||||
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
|
||||
Reference in New Issue
Block a user