diff --git a/README.md b/README.md
index a7650ca5d..aec804da3 100644
--- a/README.md
+++ b/README.md
@@ -684,6 +684,8 @@ xformers_attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
+flash_attn_fuse_qkv: # Whether to fuse the query, key, and value projections into a single operation
+flash_attn_fuse_mlp: # Whether to fuse the MLP gate/up projections and activation into a single SwiGLU operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
diff --git a/examples/llama-2/README.md b/examples/llama-2/README.md
index 500872c61..2ddd711e2 100644
--- a/examples/llama-2/README.md
+++ b/examples/llama-2/README.md
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
micro_batch_size: 1
```shell
-accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
-
+accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
```
or
```shell
-accelerate launch scripts/finetune.py examples/llama-2/lora.yml
-
+accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
+```
+
+To launch a full finetune with 16-bit precision:
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
```
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
new file mode 100644
index 000000000..a96c1cfb8
--- /dev/null
+++ b/examples/llama-2/fft_optimized.yml
@@ -0,0 +1,73 @@
+base_model: NousResearch/Llama-2-7b-hf
+base_model_config: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+ - path: mhenrichsen/alpaca_2k_test
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter:
+lora_model_dir:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_linear:
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+flash_attn_cross_entropy: false
+flash_attn_rms_norm: true
+flash_attn_fuse_qkv: false
+flash_attn_fuse_mlp: true
+
+warmup_steps: 100
+eval_steps: 0.05
+eval_table_size:
+save_steps:
+debug:
+deepspeed: #deepspeed/zero2.json # multi-gpu only
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
diff --git a/src/axolotl/monkeypatch/fused_modules.py b/src/axolotl/monkeypatch/fused_modules.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4f6b71575..386f4bfac 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -13,12 +13,18 @@ import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import (
+ LlamaMLP,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from xformers.ops import SwiGLU
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -38,6 +44,28 @@ except ImportError:
LOG = logging.getLogger("axolotl")
+def replace_llama_mlp_with_swiglu(model):
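+    """Swap every LlamaMLP module in the model for a fused SwiGLU implementation."""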
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaMLP):
+ mlp = FusedMLP(
+ module.config, module.gate_proj, module.up_proj, module.down_proj
+ )
+ set_module_name(model, name, mlp)
+
+
+def replace_llama_qkv_with_fused(model):
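+    """Swap every LlamaAttention module for a FusedAttention with a single packed QKV projection."""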
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaAttention):
+ qkv = FusedAttention(
+ module.config,
+ module.q_proj,
+ module.k_proj,
+ module.v_proj,
+ module.o_proj,
+ )
+ set_module_name(model, name, qkv)
+
+
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -86,6 +114,91 @@ def replace_llama_attn_with_flash_attn(
)
+class FusedAttention(LlamaAttention):
+ """
+ Fused QKV Attention layer for incrementally improved training efficiency
+ """
+
+ def __init__(
+ self,
+ config,
+ q: torch.nn.Linear, # pylint: disable=invalid-name
+ k: torch.nn.Linear, # pylint: disable=invalid-name
+ v: torch.nn.Linear, # pylint: disable=invalid-name
+ o: torch.nn.Linear, # pylint: disable=invalid-name
+ ):
+ super().__init__(config)
+ self.config = config
+ self.init_device = next(iter(q.state_dict().values())).device
+
+ # define equivalent fused qkv projection
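+        # record each projection's output size so the packed weight can be split back apart later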
+ self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
+ self.qkv_proj = torch.nn.Linear(
+ q.in_features, sum(self.out_features), device=self.init_device, bias=False
+ )
+ self.o_proj = o
+
+ # overwrite initialized weights with pretrained weights
+ self.qkv_proj.weight.data = torch.cat(
+ (q.weight.data, k.weight.data, v.weight.data), dim=0
+ )
+
+ def _post_training(self, model, name):
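+        # rebuild a stock LlamaAttention from the packed weights so the saved checkpoint keeps the original layout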
+ q_proj, k_proj, v_proj = torch.split(
+ self.qkv_proj.weight.data, self.out_features, dim=0
+ )
+
+ new_attn = LlamaAttention(self.config)
+ new_attn.q_proj.weight.data = q_proj
+ new_attn.k_proj.weight.data = k_proj
+        new_attn.v_proj.weight.data = v_proj
+        # carry over the trained output projection as well so it is not re-initialized
+        new_attn.o_proj.weight.data = self.o_proj.weight.data
+
+ set_module_name(model, name, new_attn)
+
+
+class FusedMLP(torch.nn.Module):
+ """
+ Fused MLP layer for incrementally improved training efficiency
+ """
+
+ def __init__(
+ self,
+ config,
+ gate_proj: torch.nn.Linear,
+ up_proj: torch.nn.Linear,
+ down_proj: torch.nn.Linear,
+ ):
+ super().__init__()
+ self.config = config
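+        # xformers SwiGLU packs the gate and up projections into a single w12 linear; w3 acts as down_proj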
+ self.swiglu = SwiGLU(
+ in_features=config.hidden_size,
+ hidden_features=config.intermediate_size,
+ bias=False,
+ _pack_weights=True,
+ )
+ # overwrite initialized weights with pretrained weights
+ self.swiglu.w12.weight.data = torch.cat(
+ (gate_proj.weight.data, up_proj.weight.data), dim=0
+ )
+ self.swiglu.w3.weight.data = down_proj.weight.data
+
+ def _post_training(self, model, name):
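+        # rebuild a stock LlamaMLP from the packed SwiGLU weights so the saved checkpoint keeps the original layout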
+ w1, w2 = torch.split( # pylint: disable=invalid-name
+ self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+ )
+
+ # Assign the split weights back to the original layers
+ new_mlp = LlamaMLP(self.config)
+ new_mlp.gate_proj.weight.data = w1
+ new_mlp.up_proj.weight.data = w2
+ new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+ set_module_name(model, name, new_mlp)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
+ return self.swiglu(x)
+
+
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -147,9 +260,14 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1)
else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ if isinstance(self, FusedAttention):
+ query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
+ self.out_features, dim=-1
+ )
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 3b007e05d..b352cc55e 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def set_module_name(model, name, value):
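+    """Replace the submodule at the dotted path ``name`` in ``model`` with ``value``."""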
+ if "." in name:
+        parent_name, child_name = name.rsplit(".", 1)
+ parent = model.get_submodule(parent_name)
+ else:
+ parent_name = ""
+ parent = model
+ child_name = name
+
+ setattr(parent, child_name, value)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 80acddb9c..468d25e14 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -40,10 +40,7 @@ class TrainDatasetMeta:
def train(
- *,
- cfg: DictDefault,
- cli_args: TrainerCliArgs,
- dataset_meta: TrainDatasetMeta,
+ *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
@@ -120,6 +117,11 @@ def train(
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
+    # post-training: unfuse any fused modules so the saved weights match the original model layout
+ for name, module in model.named_modules():
+ if hasattr(module, "_post_training"):
+ module._post_training(model, name) # pylint: disable=protected-access
+
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 9503d838c..82e2a5117 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -189,9 +189,15 @@ def validate_config(cfg):
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+ if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
+ raise ValueError("Fused modules are not supported with QLoRA")
+
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+ if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
+ raise ValueError("Fused modules are not supported with LoRA")
+
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -205,6 +211,9 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
+ if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
+ raise ValueError("Fused modules are not supported with ReLoRA")
+
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index bccb8b8e5..ea21ce8f9 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -272,6 +272,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
+
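+    # fused-module patches rewrite the already-loaded weights, so they run after from_pretrained and only for training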
+ if cfg.flash_attention and not inference:
+ from axolotl.monkeypatch.llama_attn_hijack_flash import (
+ replace_llama_mlp_with_swiglu,
+ replace_llama_qkv_with_fused,
+ )
+
+ if cfg.flash_attn_fuse_mlp:
+ LOG.info("patching with SwiGLU")
+ replace_llama_mlp_with_swiglu(model)
+
+ if cfg.flash_attn_fuse_qkv:
+ LOG.info("patching with fused QKV")
+ replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
diff --git a/tests/e2e/test_fused_llama.py b/tests/e2e/test_fused_llama.py
new file mode 100644
index 000000000..4cd45ecfd
--- /dev/null
+++ b/tests/e2e/test_fused_llama.py
@@ -0,0 +1,117 @@
+"""
+E2E tests for fused llama layers
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestFusedLlama(unittest.TestCase):
+ """
+ Test case for Llama models using Fused layers
+ """
+
+    def test_fft_packing_load_in_8bit(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "base_model_config": "JackFram/llama-68m",
+ "flash_attention": True,
+ "flash_attn_fuse_qkv": True,
+ "flash_attn_fuse_mlp": True,
+ "sample_packing": True,
+ "sequence_len": 1024,
+ "load_in_8bit": True,
+ "val_set_size": 0.1,
+ "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "pytorch_model.bin").exists()
+
+ def test_fft_packing(self):
+ # pylint: disable=duplicate-code
+ output_dir = tempfile.mkdtemp()
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "base_model_config": "JackFram/llama-68m",
+ "flash_attention": True,
+ "flash_attn_fuse_qkv": True,
+ "flash_attn_fuse_mlp": True,
+ "sample_packing": True,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": output_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ }
+ )
+ if is_torch_bf16_gpu_available():
+ cfg.bf16 = True
+ else:
+ cfg.fp16 = True
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(output_dir) / "pytorch_model.bin").exists()