From ca476d7f8eb4b97901dc631fb31b98e490e4b07b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 20 Sep 2023 13:37:32 -0400
Subject: [PATCH] don't load the actual model when pre-loading to load modeling
 code

---
 src/axolotl/monkeypatch/btlm_attn_hijack_flash.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
index be5a70559..137724765 100644
--- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
@@ -6,6 +6,7 @@
 import importlib
 import logging
 from typing import Optional, Tuple
+import accelerate
 import torch
 from flash_attn.flash_attn_interface import flash_attn_func
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
     # this is a wonky hack to get the remotely loaded module
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_btlm to be available
-    AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    with accelerate.init_empty_weights():
+        AutoModelForCausalLM.from_config(model_config, trust_remote_code=True)
     module_name = model_config.__class__.__module__.replace(
         ".configuration_btlm", ".modeling_btlm"
     )