From 8021c718ce8c2494c7ebc3f01523cdce4508f196 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 6 Aug 2025 00:13:12 -0400 Subject: [PATCH] use skip_move_to_device for all cases (#3015) * use skip_move_to_device for all cases * use experimental option for skip move --- src/axolotl/loaders/model.py | 3 +++ src/axolotl/utils/schemas/model.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 05039c9ee..1b983f7d0 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -845,6 +845,9 @@ class ModelLoader: self.model._tp_size = self.cfg.tensor_parallel_size self.model._device_mesh = self.model_kwargs["device_mesh"] + if self.cfg.experimental_skip_move_to_device is not None: + skip_move_to_device = self.cfg.experimental_skip_move_to_device + return skip_move_to_device def _set_z3_leaf_modules(self): diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 5eea11444..eae8dacb6 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -62,6 +62,14 @@ class ModelInputConfig(BaseModel): json_schema_extra={"description": "Trust remote code for untrusted source"}, ) + experimental_skip_move_to_device: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Don't move the model to the device before sharding. " + "This is an experimental feature that may be included in the future as the default." + }, + ) + @field_validator("trust_remote_code") @classmethod def hint_trust_remote_code(cls, trust_remote_code):