reorg a bit
This commit is contained in:
@@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
|
||||
Reference in New Issue
Block a user