From 15d3a654bfdc572a3bf44b9b0a74b47f8688d33c Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 21 Oct 2023 22:08:25 +0200 Subject: [PATCH] Implement fused modules (#747) * MLP: Memory saving * Remove RMSNorm restrictions * Map packed weights to original * FusedAttention module * Simplify code * Move fused modules * Fix critical typo * Split inplace * Add FFT config * Add validation of fused arguments * Add fused arguments to config * Update docs * Fix validation logic * Add fused modules to flash attn * Only fuse during training * Remove timing * Formatting * Formatting * Formatting * chore: lint * chore: lint * add e2e tests for fused llama * no lora for tests --------- Co-authored-by: Wing Lian --- README.md | 2 + examples/llama-2/README.md | 12 +- examples/llama-2/fft_optimized.yml | 73 ++++++++++ src/axolotl/monkeypatch/fused_modules.py | 0 .../monkeypatch/llama_attn_hijack_flash.py | 128 +++++++++++++++++- src/axolotl/monkeypatch/utils.py | 13 ++ src/axolotl/train.py | 10 +- src/axolotl/utils/config.py | 9 ++ src/axolotl/utils/models.py | 14 ++ tests/e2e/test_fused_llama.py | 117 ++++++++++++++++ 10 files changed, 365 insertions(+), 13 deletions(-) create mode 100644 examples/llama-2/fft_optimized.yml create mode 100644 src/axolotl/monkeypatch/fused_modules.py create mode 100644 tests/e2e/test_fused_llama.py diff --git a/README.md b/README.md index a7650ca5d..aec804da3 100644 --- a/README.md +++ b/README.md @@ -684,6 +684,8 @@ xformers_attention: flash_attention: flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only +flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation +flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation # Whether to use scaled-dot-product attention # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html sdp_attention: diff --git a/examples/llama-2/README.md b/examples/llama-2/README.md index 500872c61..2ddd711e2 100644 --- a/examples/llama-2/README.md +++ b/examples/llama-2/README.md @@ -9,12 +9,16 @@ gradient_accumulation_steps: 2 micro_batch_size: 1 ```shell -accelerate launch scripts/finetune.py examples/llama-2/qlora.yml - +accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml ``` or ```shell -accelerate launch scripts/finetune.py examples/llama-2/lora.yml - +accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml +``` + +To launch a full finetuning with 16-bit precision: + +```shell +accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml ``` diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml new file mode 100644 index 000000000..a96c1cfb8 --- /dev/null +++ b/examples/llama-2/fft_optimized.yml @@ -0,0 +1,73 @@ +base_model: NousResearch/Llama-2-7b-hf +base_model_config: NousResearch/Llama-2-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.01 +output_dir: ./out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +adapter: +lora_model_dir: +lora_r: +lora_alpha: +lora_dropout: +lora_target_linear: +lora_fan_in_fan_out: + +wandb_project: +wandb_entity: +wandb_watch: +wandb_run_id: +wandb_log_model: + 
+gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: true +fp16: false +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +flash_attn_cross_entropy: false +flash_attn_rms_norm: true +flash_attn_fuse_qkv: false +flash_attn_fuse_mlp: true + +warmup_steps: 100 +eval_steps: 0.05 +eval_table_size: +save_steps: +debug: +deepspeed: #deepspeed/zero2.json # multi-gpu only +weight_decay: 0.1 +fsdp: +fsdp_config: +special_tokens: + bos_token: "" + eos_token: "" + unk_token: "" diff --git a/src/axolotl/monkeypatch/fused_modules.py b/src/axolotl/monkeypatch/fused_modules.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 4f6b71575..386f4bfac 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -13,12 +13,18 @@ import transformers from einops import rearrange from flash_attn.bert_padding import pad_input, unpad_input from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer as OriginalLlamaDecoderLayer, ) -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.models.llama.modeling_llama import ( + LlamaMLP, + apply_rotary_pos_emb, + repeat_kv, +) +from xformers.ops import SwiGLU -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name try: from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports @@ -38,6 +44,28 @@ except ImportError: LOG = logging.getLogger("axolotl") +def replace_llama_mlp_with_swiglu(model): + for name, module in model.named_modules(): + if isinstance(module, LlamaMLP): + mlp = FusedMLP( + module.config, module.gate_proj, module.up_proj, module.down_proj + ) + set_module_name(model, name, mlp) + + +def replace_llama_qkv_with_fused(model): + for name, module in model.named_modules(): + if isinstance(module, LlamaAttention): + qkv = FusedAttention( + module.config, + module.q_proj, + module.k_proj, + module.v_proj, + module.o_proj, + ) + set_module_name(model, name, qkv) + + def replace_llama_attn_with_flash_attn( packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, @@ -86,6 +114,91 @@ def replace_llama_attn_with_flash_attn( ) +class FusedAttention(LlamaAttention): + """ + Fused QKV Attention layer for incrementally improved training efficiency + """ + + def __init__( + self, + config, + q: torch.nn.Linear, # pylint: disable=invalid-name + k: torch.nn.Linear, # pylint: disable=invalid-name + v: torch.nn.Linear, # pylint: disable=invalid-name + o: torch.nn.Linear, # pylint: disable=invalid-name + ): + super().__init__(config) + self.config = config + self.init_device = next(iter(q.state_dict().values())).device + + # define equivalent fused qkv projection + self.out_features: List[int] = [q.out_features, k.out_features, v.out_features] + self.qkv_proj = torch.nn.Linear( + q.in_features, sum(self.out_features), device=self.init_device, bias=False + ) + self.o_proj = o + + # overwrite initialized weights with 
pretrained weights + self.qkv_proj.weight.data = torch.cat( + (q.weight.data, k.weight.data, v.weight.data), dim=0 + ) + + def _post_training(self, model, name): + q_proj, k_proj, v_proj = torch.split( + self.qkv_proj.weight.data, self.out_features, dim=0 + ) + + new_attn = LlamaAttention(self.config) + new_attn.q_proj.weight.data = q_proj + new_attn.k_proj.weight.data = k_proj + new_attn.v_proj.weight.data = v_proj + + set_module_name(model, name, new_attn) + + +class FusedMLP(torch.nn.Module): + """ + Fused MLP layer for incrementally improved training efficiency + """ + + def __init__( + self, + config, + gate_proj: torch.nn.Linear, + up_proj: torch.nn.Linear, + down_proj: torch.nn.Linear, + ): + super().__init__() + self.config = config + self.swiglu = SwiGLU( + in_features=config.hidden_size, + hidden_features=config.intermediate_size, + bias=False, + _pack_weights=True, + ) + # overwrite initialized weights with pretrained weights + self.swiglu.w12.weight.data = torch.cat( + (gate_proj.weight.data, up_proj.weight.data), dim=0 + ) + self.swiglu.w3.weight.data = down_proj.weight.data + + def _post_training(self, model, name): + w1, w2 = torch.split( # pylint: disable=invalid-name + self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0 + ) + + # Assign the split weights back to the original layers + new_mlp = LlamaMLP(self.config) + new_mlp.gate_proj.weight.data = w1 + new_mlp.up_proj.weight.data = w2 + new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data + + set_module_name(model, name, new_mlp) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + return self.swiglu(x) + + # Disable the transformation of the attention mask in LlamaModel as the flash attention # requires the attention mask to be the same as the key_padding_mask def _prepare_decoder_attention_mask( @@ -147,9 +260,14 @@ def flashattn_forward( value_states = torch.cat(value_states, dim=-1) else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + if isinstance(self, FusedAttention): + query_states, key_states, value_states = self.qkv_proj(hidden_states).split( + self.out_features, dim=-1 + ) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view( bsz, q_len, self.num_heads, self.head_dim diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 3b007e05d..b352cc55e 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids): max_seq_lens.append(max_seq_len) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) + + +def set_module_name(model, name, value): + if "." in name: + parent_name = name.rsplit(".", 1)[0] + child_name = name[len(parent_name) + 1 :] + parent = model.get_submodule(parent_name) + else: + parent_name = "" + parent = model + child_name = name + + setattr(parent, child_name, value) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 80acddb9c..468d25e14 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -40,10 +40,7 @@ class TrainDatasetMeta: def train( - *, - cfg: DictDefault, - cli_args: TrainerCliArgs, - dataset_meta: TrainDatasetMeta, + *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta ): # load the tokenizer first LOG.info(f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}") @@ -120,6 +117,11 @@ def train( LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + # post training + for name, module in model.named_modules(): + if hasattr(module, "_post_training"): + module._post_training(model, name) # pylint: disable=protected-access + if trainer.is_fsdp_enabled: trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 9503d838c..82e2a5117 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -189,9 +189,15 @@ def validate_config(cfg): if not cfg.load_in_4bit: raise ValueError("Require cfg.load_in_4bit to be True for qlora") + if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: + raise ValueError("Fused modules are not supported with QLoRA") + if not cfg.load_in_8bit and cfg.adapter == "lora": LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") + if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp): + raise ValueError("Fused modules are not supported with LoRA") + if cfg.relora_steps: if cfg.adapter not in ("lora", "qlora"): raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") @@ -205,6 +211,9 @@ def validate_config(cfg): if cfg.lr_scheduler == "one_cycle": raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") + if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: + raise ValueError("Fused modules are not supported with ReLoRA") + if cfg.trust_remote_code: LOG.warning( "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index bccb8b8e5..ea21ce8f9 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -272,6 +272,20 @@ def load_model( load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, **model_kwargs, ) + + if cfg.flash_attention and not inference: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + replace_llama_mlp_with_swiglu, + replace_llama_qkv_with_fused, + ) + + if cfg.flash_attn_fuse_mlp: + LOG.info("patching with SwiGLU") + replace_llama_mlp_with_swiglu(model) + + if cfg.flash_attn_fuse_qkv: + LOG.info("patching with fused QKV") + replace_llama_qkv_with_fused(model) # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # This is a WIP, still an issue with the backward pass # RuntimeError: grad can be implicitly created only for scalar outputs diff --git a/tests/e2e/test_fused_llama.py b/tests/e2e/test_fused_llama.py new file mode 100644 index 000000000..4cd45ecfd --- /dev/null +++ b/tests/e2e/test_fused_llama.py @@ -0,0 +1,117 @@ +""" +E2E tests for lora llama +""" + +import logging +import os +import tempfile +import unittest +from pathlib import Path + +from transformers.utils import is_torch_bf16_gpu_available + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestFusedLlama(unittest.TestCase): + """ + Test case for Llama models using Fused layers + """ + + def test_lora_packing(self): + # pylint: disable=duplicate-code + output_dir = tempfile.mkdtemp() + cfg = DictDefault( + { 
+ "base_model": "JackFram/llama-68m", + "base_model_config": "JackFram/llama-68m", + "flash_attention": True, + "flash_attn_fuse_qkv": True, + "flash_attn_fuse_mlp": True, + "sample_packing": True, + "sequence_len": 1024, + "load_in_8bit": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": output_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(output_dir) / "pytorch_model.bin").exists() + + def test_fft_packing(self): + # pylint: disable=duplicate-code + output_dir = tempfile.mkdtemp() + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "base_model_config": "JackFram/llama-68m", + "flash_attention": True, + "flash_attn_fuse_qkv": True, + "flash_attn_fuse_mlp": True, + "sample_packing": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": output_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + } + ) + if is_torch_bf16_gpu_available(): + cfg.bf16 = True + else: + cfg.fp16 = True + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(output_dir) / "pytorch_model.bin").exists()
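
Note on the fused QKV projection: the patch's `FusedAttention` relies on a simple identity — stacking the `q_proj`/`k_proj`/`v_proj` weight matrices along the output dimension yields a single linear layer whose output, split at the original widths, equals the three separate projections, and splitting the stacked weight back apart is exactly what `_post_training` does before saving. Below is a minimal standalone sketch of that round trip; the layer sizes and tensors are illustrative and not taken from the patch.

```python
# Standalone sketch (not part of the patch): a fused QKV linear built by
# concatenating the q/k/v weights along dim 0 reproduces the three separate
# projections, and splitting the fused weight recovers the originals.
import torch

hidden_size, q_out, kv_out = 64, 64, 16  # illustrative sizes; k/v may be narrower than q

q = torch.nn.Linear(hidden_size, q_out, bias=False)
k = torch.nn.Linear(hidden_size, kv_out, bias=False)
v = torch.nn.Linear(hidden_size, kv_out, bias=False)

out_features = [q.out_features, k.out_features, v.out_features]
qkv = torch.nn.Linear(hidden_size, sum(out_features), bias=False)
# overwrite the fused weight with the pretrained per-projection weights
qkv.weight.data = torch.cat((q.weight.data, k.weight.data, v.weight.data), dim=0)

x = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)
q_states, k_states, v_states = qkv(x).split(out_features, dim=-1)
assert torch.allclose(q_states, q(x), atol=1e-6)
assert torch.allclose(k_states, k(x), atol=1e-6)
assert torch.allclose(v_states, v(x), atol=1e-6)

# Reversing the fusion (what _post_training does): split along dim 0.
q_w, k_w, v_w = torch.split(qkv.weight.data, out_features, dim=0)
assert torch.equal(q_w, q.weight.data)
assert torch.equal(v_w, v.weight.data)
```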
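
`FusedMLP` uses the same idea for the SwiGLU block: the packed `w12` weight holds `gate_proj` and `up_proj` stacked along dim 0 while `w3` is `down_proj`, so a single matmul produces both the gate and the up activations, and splitting `w12` at `intermediate_size` restores the stock `LlamaMLP` weights after training. The sketch below shows that packing in plain PyTorch rather than through xformers' `SwiGLU`, purely as an illustration of the weight layout; the sizes are made up.

```python
# Standalone sketch (not part of the patch) of the packed SwiGLU weight layout
# assumed by FusedMLP: w12 = [gate_proj; up_proj] stacked on dim 0, w3 = down_proj.
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 172  # illustrative sizes

gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

# Pack: one matmul now yields both the gate and the up activations.
w12 = torch.cat((gate_proj.weight.data, up_proj.weight.data), dim=0)
w3 = down_proj.weight.data

x = torch.randn(2, 5, hidden_size)
gate, up = F.linear(x, w12).split(intermediate_size, dim=-1)
fused_out = F.linear(F.silu(gate) * up, w3)

# Reference: the unfused LlamaMLP computation, down(silu(gate(x)) * up(x)).
ref_out = down_proj(F.silu(gate_proj(x)) * up_proj(x))
assert torch.allclose(fused_out, ref_out, atol=1e-6)

# Un-pack (what _post_training does): split w12 back into gate/up weights.
w_gate, w_up = torch.split(w12, intermediate_size, dim=0)
assert torch.equal(w_gate, gate_proj.weight.data)
assert torch.equal(w_up, up_proj.weight.data)
```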