Replace with flash modules

This commit is contained in:
Casper
2023-12-07 19:48:40 +01:00
parent 5302d2d534
commit bf289123e9

View File

@@ -25,8 +25,11 @@ from transformers.models.llama.modeling_llama import (
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
from axolotl.monkeypatch.flash_modules import flashattn_forward
from axolotl.monkeypatch.flash_modules import (
flashattn_forward,
replace_cross_entropy,
replace_rms_norm
)
LOG = logging.getLogger("axolotl")
@@ -69,6 +72,10 @@ def replace_llama_attn_with_flash_attn(
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
)
if cross_entropy:
replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
if rms_norm:
replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
# Disable the transformation of the attention mask in LlamaModel as the flash attention