* qwen2 multipack support

* fix qwen derived model check so it doesn't break qwen2

* fixes to ensure qwen2 packing works

* bump requirements for qwen2

* requirements typo
This commit is contained in:
Wing Lian
2024-01-22 18:24:15 -05:00
committed by GitHub
parent fccb542b47
commit f5a828aa20
5 changed files with 31 additions and 16 deletions

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.7.0
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
transformers==4.37.0
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
accelerate==0.26.1
deepspeed
addict
fire

View File

@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type == "mixtral":
if self.cfg.model_config_type in ["mixtral", "qwen2"]:
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
    """Monkeypatch HF transformers' Qwen2 modeling module so that its
    unpad helper is axolotl's multipack-aware ``get_unpad_data``,
    enabling sample packing to work with flash attention.
    """
    qwen2_modeling = transformers.models.qwen2.modeling_qwen2
    # Swap the module-level helper; protected-access is intentional here.
    # pylint: disable-next=protected-access
    qwen2_modeling._get_unpad_data = get_unpad_data

View File

@@ -142,17 +142,12 @@ def normalize_config(cfg):
)
cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
) or cfg.is_qwen_derived_model
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)

View File

@@ -334,6 +334,14 @@ def load_model(
LOG.info("patching mixtral with flash attention")
replace_mixtral_attn_with_multipack_flash_attn()
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
)
LOG.info("patching qwen2 with flash attention")
replace_qwen2_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -426,14 +434,14 @@ def load_model(
cfg.is_llama_derived_model
or cfg.is_falcon_derived_model
or cfg.is_mistral_derived_model
or model_config.model_type == "mixtral"
or model_config.model_type in ["mixtral", "qwen2"]
):
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type == "mixtral":
if model_config.model_type in ["mixtral", "qwen2"]:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"