Replace with flash modules
This commit is contained in:
@@ -25,8 +25,11 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
|
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
|
||||||
from axolotl.monkeypatch.flash_modules import flashattn_forward
|
from axolotl.monkeypatch.flash_modules import (
|
||||||
|
flashattn_forward,
|
||||||
|
replace_cross_entropy,
|
||||||
|
replace_rms_norm
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -69,6 +72,10 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
|
if cross_entropy:
|
||||||
|
replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
|
||||||
|
if rms_norm:
|
||||||
|
replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
|
|||||||
Reference in New Issue
Block a user