From 42389c1f78f566683603ee0489f37ebe2c59474c Mon Sep 17 00:00:00 2001 From: bursteratom Date: Wed, 11 Dec 2024 10:38:14 -0500 Subject: [PATCH] enable tensor parallel --- src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 +- src/axolotl/utils/models.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3671e1bb9..6862d1190 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -393,7 +393,7 @@ class ModelInputConfig(BaseModel): default=None, json_schema_extra={"description": "transformers processor class"} ) trust_remote_code: Optional[bool] = None - + tensor_parallel: Optional[Union[Literal["auto"], bool]] = "auto" model_kwargs: Optional[Dict[str, Any]] = None @field_validator("trust_remote_code") diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 11f4c6d0f..8f477ff16 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -621,6 +621,7 @@ class ModelLoader: self.model_kwargs["device_map"] = device_map self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + self.model_kwargs["tp_plan"] = self.cfg.tensor_parallel cur_device = get_device_type() if "mps" in str(cur_device):