need to update deepspeed version in extras too (#2161) [skip ci]

* need to update deepspeed version in extras too

* fix patch import

* fix monkeypatch reloading in tests and deepspeed patch

* remove duplicated functionality fixture

* reset LlamaForCausalLM too in fixtures for cce patch

* reset llama attn too

* disable xformers patch for cce

* skip problematic test on low usage functionality
This commit is contained in:
Wing Lian
2024-12-09 14:01:44 -05:00
committed by GitHub
parent 5d6b088997
commit ab4b32187d
10 changed files with 60 additions and 45 deletions

View File

@@ -71,7 +71,11 @@ class TestCutCrossEntropyIntegration:
@pytest.mark.parametrize(
"attention_type",
["flash_attention", "sdp_attention", "xformers_attention"],
[
"flash_attention",
"sdp_attention",
# "xformers_attention",
],
)
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault(