reorg a bit

This commit is contained in:
Aman Karmani
2023-09-05 02:21:24 +00:00
parent 72a6fe1c1f
commit fc8766e502

View File

@@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
try:
from flash_attn.ops.rms_norm import RMSNorm
LOG.info("patching with flash_attn.ops.rms_norm")
class LlamaRMSNorm(RMSNorm):
    """Drop-in replacement for transformers' LlamaRMSNorm backed by
    flash_attn's fused ``RMSNorm`` kernel.

    Keeps the same ``(hidden_size, eps)`` constructor signature as the
    transformers original so it can be monkey-patched over
    ``transformers.models.llama.modeling_llama.LlamaRMSNorm`` unchanged.
    """
    def __init__(self, hidden_size, eps=1e-6):
        # NOTE(review): eps default of 1e-6 presumably mirrors the
        # transformers LlamaRMSNorm default — confirm against the pinned
        # transformers version.
        super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(