Enable tensor parallelism

This commit is contained in:
bursteratom
2024-12-11 10:38:14 -05:00
parent d009ead101
commit 42389c1f78
2 changed files with 2 additions and 1 deletion

View File

@@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel):
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto"
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")

View File

@@ -621,6 +621,7 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel
cur_device = get_device_type()
if "mps" in str(cur_device):