Don't load the actual model weights when pre-instantiating the model to make the remote modeling code importable
This commit is contained in:
@@ -6,6 +6,7 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
|||||||
# this is a wonky hack to get the remotely loaded module
|
# this is a wonky hack to get the remotely loaded module
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
# we need to load the model here in order for modeling_btlm to be available
|
# we need to load the model here in order for modeling_btlm to be available
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
with accelerate.init_empty_weights():
|
||||||
|
AutoModelForCausalLM(model_config)
|
||||||
module_name = model_config.__class__.__module__.replace(
|
module_name = model_config.__class__.__module__.replace(
|
||||||
".configuration_btlm", ".modeling_btlm"
|
".configuration_btlm", ".modeling_btlm"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user