FSDP full fine-tune (FFT): load the model on the meta device (via init_empty_weights) on non-zero local ranks
This commit is contained in:
@@ -655,11 +655,19 @@ def load_model(
|
|||||||
if "device_map" in model_kwargs:
|
if "device_map" in model_kwargs:
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
if cfg.fsdp and not cfg.adapter and cfg.local_rank != 0:
|
||||||
base_model,
|
with init_empty_weights():
|
||||||
config=model_config,
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
**model_kwargs,
|
base_model,
|
||||||
)
|
config=model_config,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
config=model_config,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.flash_attention and not inference:
|
if cfg.flash_attention and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
|||||||
Reference in New Issue
Block a user