Replace with flash modules
@@ -25,8 +25,11 @@ from transformers.models.llama.modeling_llama import (
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
-from axolotl.monkeypatch.flash_modules import flashattn_forward
+from axolotl.monkeypatch.flash_modules import (
+    flashattn_forward,
+    replace_cross_entropy,
+    replace_rms_norm
+)

 LOG = logging.getLogger("axolotl")

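Note: replace_cross_entropy and replace_rms_norm are imported here but their bodies are not part of this diff. Below is a minimal sketch of what such helpers could look like, assuming they simply rebind the named symbol on the target module to flash-attn's fused equivalents (flash_attn.losses.cross_entropy.CrossEntropyLoss and flash_attn.ops.rms_norm.RMSNorm); the real axolotl.monkeypatch.flash_modules implementation may differ.

# Hypothetical sketch -- not the code from this commit. Assumes the helpers
# rebind module-level symbols to flash-attn's fused kernels.
import logging

LOG = logging.getLogger("axolotl")


def replace_cross_entropy(target_module, class_name: str) -> None:
    """Point `target_module.<class_name>` at flash-attn's fused cross entropy loss."""
    try:
        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
    except ImportError:
        LOG.warning("flash-attn fused cross entropy unavailable; keeping the stock loss")
        return
    setattr(target_module, class_name, FlashCrossEntropyLoss)


def replace_rms_norm(target_module, class_name: str) -> None:
    """Point `target_module.<class_name>` at flash-attn's fused RMSNorm."""
    try:
        from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm
    except ImportError:
        LOG.warning("flash-attn fused RMSNorm unavailable; keeping the stock norm")
        return
    setattr(target_module, class_name, FlashRMSNorm)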
@@ -69,6 +72,10 @@ def replace_llama_attn_with_flash_attn(
     transformers.models.llama.modeling_llama.LlamaModel.forward = (
         llama_model_forward
     )
+    if cross_entropy:
+        replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
+    if rms_norm:
+        replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")


 # Disable the transformation of the attention mask in LlamaModel as the flash attention
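For context, a hedged usage sketch follows. The patch has to run before the model is constructed so that modeling_llama picks up the replaced symbols; the import path, keyword flags, and checkpoint name below are assumptions inferred from the hunks above, not confirmed by this diff.

# Hypothetical usage sketch -- module path, flag names, and checkpoint are assumptions.
import transformers

from axolotl.monkeypatch.llama_attn_hijack_flash import replace_llama_attn_with_flash_attn

# Patch before instantiating the model so the flash forward, fused loss,
# and fused RMSNorm are the symbols modeling_llama actually uses.
replace_llama_attn_with_flash_attn(cross_entropy=True, rms_norm=True)

model = transformers.LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example checkpoint, assumed
    torch_dtype="auto",
)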