Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)

* Implement Mistral FA + SWA + Sample Packing (see the sketch ahead of the diff below)

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* Use TorchScript to prevent OOM (see the sketch after the diff below)

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

Author: Casper
Date: 2023-10-16 21:13:46 +02:00 (committed by GitHub)
Commit: a045db0214 (parent e1b214c62b)

2 changed files with 105 additions and 6 deletions
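
The version pin in the diff below is what makes the sliding window available: flash-attn 2.3.0 is the first release whose attention kernels accept a window_size argument. As a rough sketch of how flash attention, the sliding window, and sample packing fit together through flash-attn's varlen interface (the helper name packed_swa_attention, the tensor shapes, and the default window size are illustrative assumptions, not the commit's code):

from flash_attn import flash_attn_varlen_func  # window_size needs flash-attn >= 2.3.0

def packed_swa_attention(q, k, v, cu_seqlens, max_seqlen, sliding_window=4096):
    # q, k, v: (total_tokens, n_heads, head_dim), every packed sample
    # concatenated along the token axis; cu_seqlens holds the cumulative
    # sample boundaries so attention never crosses between packed samples.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
        # (left, right) window: each token sees at most sliding_window
        # tokens behind it and none ahead, which is Mistral's SWA pattern.
        window_size=(sliding_window, 0),
    )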

setup.py

@@ -46,7 +46,7 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn>=2.2.1",
+            "flash-attn>=2.3.0",
         ],
         "deepspeed": [
             "deepspeed",