Qwen2 (#1166)
* qwen2 multipack support
* fix qwen derived model check so it doesn't break qwen2
* fixes to ensure qwen2 packing works
* bump requirements for qwen2
* requirements typo
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.7.0
|
peft==0.7.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
transformers==4.37.0
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate.git@0d2280dadc6a93413a5496613b7fdda3a4d2551b
|
accelerate==0.26.1
|
||||||
deepspeed
|
deepspeed
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
|
|||||||
@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type == "mixtral":
|
if self.cfg.model_config_type in ["mixtral", "qwen2"]:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
else:
|
else:
|
||||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
|||||||
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Patches to support multipack for qwen2
|
||||||
|
"""
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
|
|
||||||
|
def replace_qwen2_attn_with_multipack_flash_attn():
    """
    Monkeypatch the qwen2 modeling module for multipack support.

    Rebinds the module-level ``_get_unpad_data`` helper in
    ``transformers.models.qwen2.modeling_qwen2`` to axolotl's
    ``get_unpad_data``. The patch is process-wide: it replaces the attribute
    on the imported module object and returns nothing.
    """
    qwen2_modeling = transformers.models.qwen2.modeling_qwen2
    # NOTE(review): presumably the qwen2 flash-attention path looks this helper
    # up at call time, so the swap takes effect for later forward passes —
    # confirm against the pinned transformers version.
    qwen2_modeling._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )
|
||||||
@@ -142,17 +142,12 @@ def normalize_config(cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
cfg.is_qwen_derived_model = (
|
cfg.is_qwen_derived_model = (
|
||||||
(
|
hasattr(model_config, "model_type")
|
||||||
hasattr(model_config, "model_type")
|
and model_config.model_type
|
||||||
and model_config.model_type
|
in [
|
||||||
in [
|
"qwen",
|
||||||
"qwen",
|
]
|
||||||
]
|
) or cfg.is_qwen_derived_model
|
||||||
)
|
|
||||||
or cfg.is_qwen_derived_model
|
|
||||||
or "qwen" in cfg.base_model.lower()
|
|
||||||
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(cfg.learning_rate, str):
|
if isinstance(cfg.learning_rate, str):
|
||||||
cfg.learning_rate = float(cfg.learning_rate)
|
cfg.learning_rate = float(cfg.learning_rate)
|
||||||
|
|||||||
@@ -334,6 +334,14 @@ def load_model(
|
|||||||
LOG.info("patching mixtral with flash attention")
|
LOG.info("patching mixtral with flash attention")
|
||||||
replace_mixtral_attn_with_multipack_flash_attn()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
|
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.qwen2 import (
|
||||||
|
replace_qwen2_attn_with_multipack_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching qwen2 with flash attention")
|
||||||
|
replace_qwen2_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
@@ -426,14 +434,14 @@ def load_model(
|
|||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
or cfg.is_falcon_derived_model
|
or cfg.is_falcon_derived_model
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_mistral_derived_model
|
||||||
or model_config.model_type == "mixtral"
|
or model_config.model_type in ["mixtral", "qwen2"]
|
||||||
):
|
):
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if model_config.model_type == "mixtral":
|
if model_config.model_type in ["mixtral", "qwen2"]:
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
|
|||||||
Reference in New Issue
Block a user