flash_attention + sample packing for stablelm 3b (#671)

* stablelm epoch fa patch

* is causal for fa

* working stablelm fa w packing

* chore: pre-commit linting
This commit is contained in:
Wing Lian
2023-10-05 16:03:43 -04:00
committed by GitHub
parent eb480dfd68
commit 2d60ba3a6e
3 changed files with 429 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ import logging
from typing import Optional, Tuple
import torch
from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
@@ -17,7 +18,8 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)