liger + torch compile fix

This commit is contained in:
Wing Lian
2025-04-01 15:32:27 -04:00
parent af3f981f51
commit 252dc5c91b
2 changed files with 41 additions and 0 deletions

View File

@@ -27,6 +27,7 @@ from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
LOG = logging.getLogger("axolotl.integrations.liger")
@@ -40,6 +41,17 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP

View File

@@ -0,0 +1,29 @@
"""
utils to patch liger kernel ops to disable torch.compile
"""
from functools import wraps
import torch
def patch_with_compile_disable(module, function_name):
"""
Patch a function in a module by wrapping it with torch.compile.disable
Args:
module: The module containing the function to patch
function_name: The name of the function to patch
"""
original_function = getattr(module, function_name)
@wraps(original_function)
@torch.compiler.disable
def wrapped_function(*args, **kwargs):
return original_function(*args, **kwargs)
# Replace the original function with the wrapped one
setattr(module, function_name, wrapped_function)
# Return the original function in case you need to restore it later
return original_function