Compare commits


1 Commit

Author: Wing Lian
SHA1: ca476d7f8e
Message: don't load the actual model when pre-loading to load modeling code
Date: 2023-09-20 13:37:32 -04:00


@@ -6,6 +6,7 @@ import importlib
 import logging
 from typing import Optional, Tuple
 
+import accelerate
 import torch
 from flash_attn.flash_attn_interface import flash_attn_func
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
     # this is a wonky hack to get the remotely loaded module
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_btlm to be available
-    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    with accelerate.init_empty_weights():
+        AutoModelForCausalLM(model_config)
     module_name = model_config.__class__.__module__.replace(
         ".configuration_btlm", ".modeling_btlm"
     )
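
The change replaces a full from_pretrained call, which downloads and materializes every weight tensor, with an instantiation under accelerate's init_empty_weights context, whose only purpose is to force the remotely hosted modeling_btlm module to be imported so its attention can be patched. The following is a minimal sketch of that technique, not the commit's exact code: the helper name load_btlm_modeling_module is illustrative, and it uses AutoModelForCausalLM.from_config (the documented way to build a model from a config with trust_remote_code) in place of the bare constructor call shown in the diff.

    # Sketch: make remote modeling code importable without loading weights.
    import importlib

    from accelerate import init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM


    def load_btlm_modeling_module(model_name="cerebras/btlm-3b-8k-base"):
        # Fetching the config pulls down the remote code package,
        # but not the checkpoint shards.
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

        # Inside init_empty_weights() every parameter is created on the
        # "meta" device, so no weight memory is allocated and nothing
        # beyond the code/config is downloaded.
        with init_empty_weights():
            AutoModelForCausalLM.from_config(config, trust_remote_code=True)

        # Once instantiated, the dynamically loaded modeling_btlm module
        # exists and can be imported for monkey-patching.
        module_name = type(config).__module__.replace(
            ".configuration_btlm", ".modeling_btlm"
        )
        return importlib.import_module(module_name)

On the design choice: instantiating from the config on the meta device keeps the pre-load step cheap (no multi-GB download, no host RAM spike) while still triggering the side effect the patch actually needs, namely that the transformers dynamic-module loader has registered modeling_btlm.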