diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py
index cdc09ae0a..3efe45a51 100644
--- a/src/axolotl/core/trainers/__init__.py
+++ b/src/axolotl/core/trainers/__init__.py
@@ -2,15 +2,36 @@
 
 # flake8: noqa
 
+from axolotl.utils import make_lazy_getattr
+
 from .base import AxolotlTrainer
-from .dpo.trainer import AxolotlDPOTrainer
-from .ebft.strided import AxolotlStridedEBFTTrainer
-from .ebft.trainer import AxolotlEBFTTrainer
-from .mamba import AxolotlMambaTrainer
-from .trl import (
-    AxolotlCPOTrainer,
-    AxolotlKTOTrainer,
-    AxolotlORPOTrainer,
-    AxolotlPRMTrainer,
-    AxolotlRewardTrainer,
-)
+
+# noinspection PyUnresolvedReferences
+__all__ = [
+    "AxolotlTrainer",
+    "AxolotlCPOTrainer",
+    "AxolotlDPOTrainer",
+    "AxolotlEBFTTrainer",
+    "AxolotlKTOTrainer",
+    "AxolotlMambaTrainer",
+    "AxolotlORPOTrainer",
+    "AxolotlPRMTrainer",
+    "AxolotlRewardTrainer",
+    "AxolotlStridedEBFTTrainer",
+]
+
+# Exported name -> relative submodule defining it; imported on first access.
+_LAZY_IMPORTS = {
+    "AxolotlDPOTrainer": ".dpo.trainer",
+    "AxolotlStridedEBFTTrainer": ".ebft.strided",
+    "AxolotlEBFTTrainer": ".ebft.trainer",
+    "AxolotlMambaTrainer": ".mamba",
+    "AxolotlCPOTrainer": ".trl",
+    "AxolotlKTOTrainer": ".trl",
+    "AxolotlORPOTrainer": ".trl",
+    "AxolotlPRMTrainer": ".trl",
+    "AxolotlRewardTrainer": ".trl",
+}
+
+# Module-level ``__getattr__`` (PEP 562) resolves the names above on demand.
+__getattr__ = make_lazy_getattr(_LAZY_IMPORTS, __name__, globals())
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index 476f3b2e9..d1d825c3a 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -9,6 +9,42 @@ import re
 import torch
 
 
+def make_lazy_getattr(
+    lazy_imports: dict[str, str], module_name: str, module_globals: dict
+):
+    """Build a PEP 562 module ``__getattr__`` that imports names on demand.
+
+    Args:
+        lazy_imports: Maps each lazily exposed attribute name to the
+            relative path of the module defining it, for example
+            ``{"AxolotlDPOTrainer": ".dpo.trainer"}``.
+        module_name: ``__name__`` of the module installing the hook; given
+            to ``importlib.import_module`` as the package that relative
+            paths are resolved against.
+        module_globals: ``globals()`` of that same module. Every resolved
+            attribute is stored in it, so later lookups find the attribute
+            directly and the hook is not called again for that name.
+
+    Returns:
+        A function to assign to ``__getattr__`` at module scope. It raises
+        ``AttributeError`` for names that are not keys of ``lazy_imports``.
+    """
+    import importlib
+
+    def __getattr__(name: str):
+        # Guard first so unknown names fail like an ordinary attribute miss.
+        if name not in lazy_imports:
+            raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
+
+        source_module = importlib.import_module(lazy_imports[name], module_name)
+        resolved = getattr(source_module, name)
+        # Cache in the caller's namespace so this hook runs once per name.
+        module_globals[name] = resolved
+        return resolved
+
+    return __getattr__
+
+
 def is_mlflow_available():
     return importlib.util.find_spec("mlflow") is not None
 