diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index b546793f3..108e30bab 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -3,7 +3,6 @@
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
 import logging
-from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index 16a4e557d..98f747e66 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -6,17 +6,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import transformers
-from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
-    flash_attn_kvpacked_func,
-    flash_attn_varlen_kvpacked_func,
-    flash_attn_varlen_qkvpacked_func,
-)
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    MistralAttention as OriginalMistralAttention,
-)
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
     MistralMLP