bump transformers and update attention class map name (#1023)

* bump transformers and update attention class map name

* also run the tests in docker

* add mixtral e2e smoke test

* fix base name for docker image in test

* mixtral lora doesn't seem to work, at least check qlora

* add testcase for mixtral w sample packing

* check monkeypatch for flash attn multipack

* also run the e2e tests in docker

* use all gpus to run tests in docker ci

* use privileged mode too for docker w gpus

* rename the docker e2e actions for gh ci

* set privileged mode for docker and update mixtral model self attn check

* use fp16/bf16 for mixtral w fa2

* skip e2e tests on docker w gpus for now

* tests to validate mistral and mixtral patches

* fix rel import
This commit is contained in:
Wing Lian
2024-01-03 15:11:04 -05:00
committed by GitHub
parent 74532ddc45
commit bcc78d8fa3
8 changed files with 404 additions and 4 deletions

View File

@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
mixtral_model_forward
)
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
"flash_attention_2"
] = MixtralMultipackFlashAttention2

View File

@@ -261,7 +261,11 @@ def mixtral_model_forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
if (
attention_mask is not None
and self._attn_implementation == "flash_attention_2"
and use_cache
):
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
@@ -270,7 +274,7 @@ def mixtral_model_forward(
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if self._use_flash_attention_2:
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = (
attention_mask

View File

@@ -332,15 +332,18 @@ def load_model(
or cfg.is_mistral_derived_model
or model_config.model_type == "mixtral"
):
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)