Falcon embeddings (#1149) [skip docker]
* also fix multipack for falcon and add smoke tests
* make sure to handle special tokens and added tokens for lora
* fix reference to model_type
* fix tests for falcon
* fix stray typo
* fixes for smoke tests
This commit is contained in:
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for falcon
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_falcon_attn_with_multipack_flash_attn():
    """Patch transformers' Falcon modeling module to support multipack.

    Rebinds the module-level ``_get_unpad_data`` attribute of
    ``transformers.models.falcon.modeling_falcon`` to axolotl's
    ``get_unpad_data``. The change is applied to the imported module object,
    so it stays in effect for the rest of the process.
    """
    falcon_modeling = transformers.models.falcon.modeling_falcon
    # Overwrite the private helper in place with axolotl's replacement.
    falcon_modeling._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )
|
||||
@@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
|
||||
return ["embd.wte", "lm_head.linear"]
|
||||
if model_type == "gpt_neox":
|
||||
return ["embed_in", "embed_out"]
|
||||
if model_type == "falcon":
|
||||
return ["word_embeddings", "lm_head"]
|
||||
return ["embed_tokens", "lm_head"]
|
||||
|
||||
@@ -334,6 +334,14 @@ def load_model(
|
||||
LOG.info("patching mixtral with flash attention")
|
||||
replace_mixtral_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.falcon import (
|
||||
replace_falcon_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching falcon with flash attention")
|
||||
replace_falcon_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.qwen2 import (
|
||||
replace_qwen2_attn_with_multipack_flash_attn,
|
||||
@@ -434,18 +442,13 @@ def load_model(
|
||||
if not cfg.sample_packing:
|
||||
if cfg.s2_attention:
|
||||
pass
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
or model_config.model_type in ["mixtral", "qwen2"]
|
||||
):
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
# most other models support flash attention, we can define exceptions as they come up
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type in ["mixtral", "qwen2"]:
|
||||
if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
@@ -461,7 +464,11 @@ def load_model(
|
||||
model_config.fused_dense = True
|
||||
|
||||
try:
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
if (
|
||||
model_config.model_type == "llama"
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
):
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
@@ -755,8 +762,10 @@ def find_all_linear_names(model):
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
embedding_modules = get_linear_embedding_layers(model.config.model_type)
|
||||
output_embedding = embedding_modules[1]
|
||||
if output_embedding in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove(output_embedding)
|
||||
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
@@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
if cfg.model_config_type == "falcon":
|
||||
LOG.info("dropping token_type_ids column")
|
||||
train_dataset = train_dataset.remove_columns("token_type_ids")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
|
||||
Reference in New Issue
Block a user