reorg a bit
This commit is contained in:
@@ -64,14 +64,13 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
|||||||
try:
|
try:
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
class LlamaRMSNorm(RMSNorm):
|
||||||
"""Patched LLamaRMSNorm"""
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__(hidden_size, eps=eps)
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
|
|||||||
Reference in New Issue
Block a user