Add shifted sparse attention (#973) [skip-ci]

* Add s2_attn to hijack flash code

* Refactor code to account for s2_attn

* Add test for models utils

* Add ``s2_attention`` option to llama configs

* Add ``s2_attention`` option to README config

* Format code to appease linter

* chore: lint

* Remove xpos and llama-landmark [bad merge]

* add e2e smoke tests for shifted sparse attention

* remove stray patch from merge

* update yml with link to paper for s2_attention/longlora

* fix assertion check for full fine tune

* increase sequence len for tests and PR feedback updates

* reduce context len to 16k for tests

* reduce context len to 16k for tests

* reduce batch size for larger context len and udpate test to check message

* fix test for message

---------

Co-authored-by: joecummings <jrcummings@devvm050.nha0.facebook.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Joe Cummings
2024-01-18 10:16:07 -05:00
committed by GitHub
parent 317fa2555a
commit 1d70f24b50
10 changed files with 339 additions and 19 deletions

View File

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4

View File

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4

View File

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4