Falcon embeddings (#1149) [skip docker]
* also fix multipack for falcon and add smoke tests
* make sure to handle special tokens and added tokens for lora
* fix reference to model_type
* fix tests for falcon
* fix stray typo
* fixes for smoke tests
This commit is contained in:
@@ -60,5 +60,5 @@ fsdp:
|
|||||||
fsdp_config:
|
fsdp_config:
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
bos_token: ">>ABSTRACT<<"
|
bos_token: "<|endoftext|>"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|||||||
@@ -89,5 +89,5 @@ fsdp:
|
|||||||
fsdp_config:
|
fsdp_config:
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
bos_token: ">>ABSTRACT<<"
|
bos_token: "<|endoftext|>"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|||||||
@@ -60,5 +60,5 @@ fsdp:
|
|||||||
fsdp_config:
|
fsdp_config:
|
||||||
special_tokens:
|
special_tokens:
|
||||||
pad_token: "<|endoftext|>"
|
pad_token: "<|endoftext|>"
|
||||||
bos_token: ">>ABSTRACT<<"
|
bos_token: "<|endoftext|>"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|||||||
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Patches to support multipack for falcon
|
||||||
|
"""
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
|
|
||||||
|
def replace_falcon_attn_with_multipack_flash_attn():
|
||||||
|
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
@@ -11,4 +11,6 @@ def get_linear_embedding_layers(model_type):
|
|||||||
return ["embd.wte", "lm_head.linear"]
|
return ["embd.wte", "lm_head.linear"]
|
||||||
if model_type == "gpt_neox":
|
if model_type == "gpt_neox":
|
||||||
return ["embed_in", "embed_out"]
|
return ["embed_in", "embed_out"]
|
||||||
|
if model_type == "falcon":
|
||||||
|
return ["word_embeddings", "lm_head"]
|
||||||
return ["embed_tokens", "lm_head"]
|
return ["embed_tokens", "lm_head"]
|
||||||
|
|||||||
@@ -334,6 +334,14 @@ def load_model(
|
|||||||
LOG.info("patching mixtral with flash attention")
|
LOG.info("patching mixtral with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
|
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.falcon import (
|
||||||
|
replace_falcon_attn_with_multipack_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching falcon with flash attention")
|
||||||
|
replace_falcon_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||||
from axolotl.monkeypatch.qwen2 import (
|
from axolotl.monkeypatch.qwen2 import (
|
||||||
replace_qwen2_attn_with_multipack_flash_attn,
|
replace_qwen2_attn_with_multipack_flash_attn,
|
||||||
@@ -434,18 +442,13 @@ def load_model(
|
|||||||
if not cfg.sample_packing:
|
if not cfg.sample_packing:
|
||||||
if cfg.s2_attention:
|
if cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
if (
|
# most other models support flash attention, we can define exceptions as they come up
|
||||||
cfg.is_llama_derived_model
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
or cfg.is_falcon_derived_model
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
or cfg.is_mistral_derived_model
|
"flash_attention_2"
|
||||||
or model_config.model_type in ["mixtral", "qwen2"]
|
)
|
||||||
):
|
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flash_attention_2"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if model_config.model_type in ["mixtral", "qwen2"]:
|
if model_config.model_type in ["mixtral", "qwen2", "falcon"]:
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
@@ -461,7 +464,11 @@ def load_model(
|
|||||||
model_config.fused_dense = True
|
model_config.fused_dense = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if (
|
||||||
|
model_config.model_type == "llama"
|
||||||
|
and not cfg.trust_remote_code
|
||||||
|
and not cfg.gptq
|
||||||
|
):
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
@@ -755,8 +762,10 @@ def find_all_linear_names(model):
|
|||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
embedding_modules = get_linear_embedding_layers(model.config.model_type)
|
||||||
lora_module_names.remove("lm_head")
|
output_embedding = embedding_modules[1]
|
||||||
|
if output_embedding in lora_module_names: # needed for 16-bit
|
||||||
|
lora_module_names.remove(output_embedding)
|
||||||
|
|
||||||
return list(lora_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|||||||
@@ -124,6 +124,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
|
|
||||||
|
if cfg.model_config_type == "falcon":
|
||||||
|
LOG.info("dropping token_type_ids column")
|
||||||
|
train_dataset = train_dataset.remove_columns("token_type_ids")
|
||||||
|
if eval_dataset:
|
||||||
|
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||||
|
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
num_proc=cfg.dataset_processes,
|
||||||
|
|||||||
112
tests/e2e/patched/test_falcon_samplepack.py
Normal file
112
tests/e2e/patched/test_falcon_samplepack.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for falcon
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFalconPatched(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Falcon models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_qlora(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_dropout": 0.1,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"lora_modules_to_save": ["word_embeddings", "lm_head"],
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<|endoftext|>",
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_ft(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<|endoftext|>",
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||||
@@ -32,6 +32,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
@@ -57,7 +58,6 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -76,6 +76,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||||
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.1,
|
||||||
"special_tokens": {},
|
"special_tokens": {},
|
||||||
@@ -95,7 +96,6 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"save_steps": 10,
|
"save_steps": 10,
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
166
tests/e2e/test_falcon.py
Normal file
166
tests/e2e/test_falcon.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for falcon
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFalcon(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for falcon
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"lora_modules_to_save": [
|
||||||
|
"word_embeddings",
|
||||||
|
"lm_head",
|
||||||
|
],
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<|endoftext|>",
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora_added_vocab(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 32,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"lora_modules_to_save": [
|
||||||
|
"word_embeddings",
|
||||||
|
"lm_head",
|
||||||
|
],
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<|endoftext|>",
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"tokens": [
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
],
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_ft(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"bos_token": "<|endoftext|>",
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||||
Reference in New Issue
Block a user