diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 5ac66260a..b81e713cf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -655,11 +655,19 @@ def load_model( if "device_map" in model_kwargs: del model_kwargs["device_map"] - model = AutoModelForCausalLM.from_pretrained( - base_model, - config=model_config, - **model_kwargs, - ) + if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0: + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained( + base_model, + config=model_config, + **model_kwargs, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + base_model, + config=model_config, + **model_kwargs, + ) if cfg.flash_attention and not inference: from axolotl.monkeypatch.llama_attn_hijack_flash import (