Falcon embeddings (#1149) [skip docker]

* also fix multipack for falcon and add smoke tests

* make sure to handle special tokens and added tokens for LoRA

* fix reference to model_type

* fix tests for falcon

* fix stray typo

* fixes for smoke tests
Author:    Wing Lian
Date:      2024-01-22 21:01:42 -05:00
Committed: GitHub
Parent:    0f77b8d798
Commit:    e799e08d3c

10 changed files with 326 additions and 19 deletions


@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
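
For context: stock transformers derives flash-attention unpadding metadata from a 0/1 attention mask, while axolotl's multipack collator packs several samples into one row and marks them with per-sample ids in the mask. The patch above points Falcon at axolotl's get_unpad_data, which respects those per-sample boundaries. A rough, illustrative sketch of that idea (not the actual axolotl implementation) looks like this:

import torch
import torch.nn.functional as F

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # With sample packing, the mask carries ids such as [[1, 1, 1, 2, 2, 0]],
    # so sequence lengths are counted per id within each row, not per row.
    seqlens = []
    for row in attention_mask:
        counts = torch.bincount(row[row > 0])
        seqlens.extend(counts[counts > 0].tolist())
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=attention_mask.device)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())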


@@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
return ["embd.wte", "lm_head.linear"]
if model_type == "gpt_neox":
return ["embed_in", "embed_out"]
if model_type == "falcon":
return ["word_embeddings", "lm_head"]
return ["embed_tokens", "lm_head"]


@@ -334,6 +334,14 @@ def load_model(
LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
replace_falcon_attn_with_multipack_flash_attn,
)
LOG.info("patching falcon with flash attention")
replace_falcon_attn_with_multipack_flash_attn()
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
@@ -434,18 +442,13 @@ def load_model(
         if not cfg.sample_packing:
             if cfg.s2_attention:
                 pass
-            if (
-                cfg.is_llama_derived_model
-                or cfg.is_falcon_derived_model
-                or cfg.is_mistral_derived_model
-                or model_config.model_type in ["mixtral", "qwen2"]
-            ):
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
+            # most other models support flash attention, we can define exceptions as they come up
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+            model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flash_attention_2"
+            )
         else:
-            if model_config.model_type in ["mixtral", "qwen2"]:
+            if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
@@ -461,7 +464,11 @@ def load_model(
         model_config.fused_dense = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            model_config.model_type == "llama"
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+        ):
             from transformers import LlamaForCausalLM

             model = LlamaForCausalLM.from_pretrained(
@@ -755,8 +762,10 @@ def find_all_linear_names(model):
             names = name.split(".")
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])

-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    embedding_modules = get_linear_embedding_layers(model.config.model_type)
+    output_embedding = embedding_modules[1]
+    if output_embedding in lora_module_names:  # needed for 16-bit
+        lora_module_names.remove(output_embedding)

     return list(lora_module_names)
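
Previously the output projection was hard-coded as "lm_head"; looking it up via get_linear_embedding_layers keeps the exclusion correct for architectures that name it differently (e.g. phi's "lm_head.linear"). A small illustrative walk-through, with typical Falcon linear-module names listed by hand:

# Module names below are typical Falcon linears, written out manually for illustration.
lora_module_names = {"query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h", "lm_head"}

embedding_modules = get_linear_embedding_layers("falcon")  # ["word_embeddings", "lm_head"]
output_embedding = embedding_modules[1]
if output_embedding in lora_module_names:  # needed for 16-bit
    lora_module_names.remove(output_embedding)

print(sorted(lora_module_names))
# ['dense', 'dense_4h_to_h', 'dense_h_to_4h', 'query_key_value']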


@@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")

+        if cfg.model_config_type == "falcon":
+            LOG.info("dropping token_type_ids column")
+            train_dataset = train_dataset.remove_columns("token_type_ids")
+            if eval_dataset:
+                eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
     train_dataset = train_dataset.filter(
         drop_long,
         num_proc=cfg.dataset_processes,
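
Falcon's tokenizer can emit a token_type_ids column that the packing collator does not use, so it is dropped the same way attention_mask is above. A minimal datasets example of the same operation:

from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "token_type_ids": [[0, 0, 0]]})
ds = ds.remove_columns("token_type_ids")
print(ds.column_names)  # ['input_ids']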