diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 05039c9ee..1b983f7d0 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -845,6 +845,9 @@ class ModelLoader: self.model._tp_size = self.cfg.tensor_parallel_size self.model._device_mesh = self.model_kwargs["device_mesh"] + if self.cfg.experimental_skip_move_to_device is not None: + skip_move_to_device = self.cfg.experimental_skip_move_to_device + return skip_move_to_device def _set_z3_leaf_modules(self): diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 5eea11444..eae8dacb6 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -62,6 +62,14 @@ class ModelInputConfig(BaseModel): json_schema_extra={"description": "Trust remote code for untrusted source"}, ) + experimental_skip_move_to_device: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Don't move the model to the device before sharding. " + "This is an experimental feature that may be included in the future as the default." + }, + ) + @field_validator("trust_remote_code") @classmethod def hint_trust_remote_code(cls, trust_remote_code):