Remove unused imports
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -6,17 +6,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
|
||||||
flash_attn_kvpacked_func,
|
|
||||||
flash_attn_varlen_kvpacked_func,
|
|
||||||
flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
MistralAttention as OriginalMistralAttention,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
MistralMLP
|
MistralMLP
|
||||||
|
|||||||
Reference in New Issue
Block a user