From 252dc5c91b809f33cd7b259bc77b4a534e2781d6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 1 Apr 2025 15:32:27 -0400 Subject: [PATCH] liger + torch compile fix --- src/axolotl/integrations/liger/__init__.py | 12 +++++++++ src/axolotl/integrations/liger/utils.py | 29 ++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 src/axolotl/integrations/liger/utils.py diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index d6e423fa9..8e0508ee0 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -27,6 +27,7 @@ from axolotl.integrations.base import BasePlugin from ...utils.distributed import zero_only from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 +from .utils import patch_with_compile_disable LOG = logging.getLogger("axolotl.integrations.liger") @@ -40,6 +41,17 @@ class LigerPlugin(BasePlugin): return "axolotl.integrations.liger.LigerArgs" def pre_model_load(self, cfg): + if cfg.torch_compile: + import liger_kernel.ops.fused_linear_cross_entropy + + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_forward", + ) + patch_with_compile_disable( + liger_kernel.ops.fused_linear_cross_entropy, + "fused_linear_cross_entropy_backward", + ) from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.geglu import LigerGEGLUMLP diff --git a/src/axolotl/integrations/liger/utils.py b/src/axolotl/integrations/liger/utils.py new file mode 100644 index 000000000..bf9fc58e7 --- /dev/null +++ b/src/axolotl/integrations/liger/utils.py @@ -0,0 +1,29 @@ +""" +utils to patch liger kernel ops to disable torch.compile +""" + +from functools import wraps + +import torch + + +def patch_with_compile_disable(module, function_name): + """ + Patch a function in a module by wrapping it with torch.compile.disable + + Args: + module: The module containing the function to patch + function_name: The name of the function to patch + """ + original_function = getattr(module, function_name) + + @wraps(original_function) + @torch.compiler.disable + def wrapped_function(*args, **kwargs): + return original_function(*args, **kwargs) + + # Replace the original function with the wrapped one + setattr(module, function_name, wrapped_function) + + # Return the original function in case you need to restore it later + return original_function