Qwen2 (#1166)
* qwen2 multipack support
* fix qwen derived model check so it doesn't break qwen2
* fixes to ensure qwen2 packing works
* bump requirements for qwen2
* requirements typo
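For orientation, the changes below are exercised when a qwen2-architecture model is trained with flash attention and sample packing both enabled. A hypothetical set of cfg values that reaches the new code paths (attribute names are taken from the checks in the diffs below; the checkpoint name is illustrative, not from the commit):

# Hypothetical cfg values that exercise the new qwen2 multipack path.
cfg = dict(
    base_model="Qwen/Qwen1.5-7B",  # illustrative qwen2-architecture checkpoint
    flash_attention=True,          # gates the flash-attn monkeypatch below
    sample_packing=True,           # gates the multipack collator below
)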
requirements.txt
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
+accelerate==0.26.1
 deepspeed
 addict
 fire
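The exact pins matter here: to my knowledge, transformers 4.37.0 is the first release that ships the qwen2 modeling code the monkeypatch below targets. A quick sanity check (not part of the commit) that the pinned release exposes the symbol being patched:

# Sanity check (not part of the commit): the pinned transformers release
# must ship the qwen2 module that the monkeypatch below rebinds.
import transformers
import transformers.models.qwen2.modeling_qwen2 as qwen2_modeling  # ImportError before 4.37.0

print(transformers.__version__)
assert hasattr(qwen2_modeling, "_get_unpad_data")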
src/axolotl/core/trainer_builder.py
@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type == "mixtral":
+            if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
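The V2 collator differs from the base one in how it builds attention masks for packed rows: rather than a flat 0/1 mask, each packed sub-sequence gets its own integer id, which the unpad patch added below uses to recover per-sample boundaries. A minimal illustration (shapes and values are hypothetical):

import torch

# One packed row holding two samples (lengths 3 and 2) plus one pad token.
plain_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])      # v1-style: real vs. pad only
multipack_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])  # v2-style: ids mark sample boundaries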
src/axolotl/monkeypatch/qwen2/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
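This swaps transformers' stock _get_unpad_data, which assumes a 0/1 mask with one sequence per row, for axolotl's multipack-aware version. A minimal sketch of what such a replacement has to compute, assuming the id-numbered masks from the V2 collator shown earlier (an illustration, not the actual axolotl implementation):

import torch
import torch.nn.functional as F

def multipack_get_unpad_data(attention_mask: torch.Tensor):
    # Each packed row carries sample ids (1, 2, ...) instead of plain 0/1,
    # so sequence lengths are counted per sample id, not per row.
    max_id = int(attention_mask.max().item())
    seqlens = []
    for row in attention_mask:
        for sample_id in range(1, max_id + 1):
            count = int((row == sample_id).sum().item())
            if count:
                seqlens.append(count)
    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32)
    # Flat indices of every real (non-padding) token in the batch.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # Cumulative lengths with a leading 0, the layout flash-attn's
    # varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

For the example mask [[1, 1, 1, 2, 2, 0]] this yields cu_seqlens == [0, 3, 5], i.e. two separate sequences inside one packed row.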
src/axolotl/utils/config.py
@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )

     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model

     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
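The removed substring checks were the bug: any base_model name containing "qwen" (including qwen2 checkpoints) was flagged as qwen-derived, turning on qwen1-specific handling that breaks qwen2. An illustration with a hypothetical checkpoint name:

base_model = "Qwen/Qwen1.5-7B"  # hypothetical qwen2-architecture checkpoint
model_type = "qwen2"

old_check = "qwen" in base_model.lower() or (model_type and "qwen" in model_type.lower())
new_check = model_type in ["qwen"]

print(old_check)  # True:  qwen2 was wrongly treated as qwen-derived
print(new_check)  # False: only the original qwen architecture matches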
src/axolotl/utils/models.py
@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

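Since the patch only rebinds a module-level name, a quick way to confirm it took effect (a hypothetical smoke test, not in the commit):

import transformers
from axolotl.monkeypatch.qwen2 import replace_qwen2_attn_with_multipack_flash_attn
from axolotl.monkeypatch.utils import get_unpad_data

replace_qwen2_attn_with_multipack_flash_attn()
assert (
    transformers.models.qwen2.modeling_qwen2._get_unpad_data  # pylint: disable=protected-access
    is get_unpad_data
)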
@@ -426,14 +434,14 @@ def load_model(
             cfg.is_llama_derived_model
             or cfg.is_falcon_derived_model
             or cfg.is_mistral_derived_model
-            or model_config.model_type == "mixtral"
+            or model_config.model_type in ["mixtral", "qwen2"]
         ):
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
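For context, attn_implementation is how recent transformers releases select the flash-attention-2 code path at model construction, so the kwargs set above end up in a call roughly like this (checkpoint name illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-7B",                        # illustrative qwen2 checkpoint
    attn_implementation="flash_attention_2",  # the kwarg the diff sets
)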