lazy load trainer classes to prevent unnecesary imports (#3568)

* lazy load trainer classes to prevent unnecesary imports

* make the lazy load a common util
This commit is contained in:
Wing Lian
2026-04-01 13:29:04 -04:00
committed by GitHub
parent 1b1fc917bc
commit 6c92b5c31c
2 changed files with 60 additions and 11 deletions

View File

@@ -2,15 +2,34 @@
# flake8: noqa
from axolotl.utils import make_lazy_getattr
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .ebft.strided import AxolotlStridedEBFTTrainer
from .ebft.trainer import AxolotlEBFTTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
)
# noinspection PyUnresolvedReferences
__all__ = [
"AxolotlTrainer",
"AxolotlCPOTrainer",
"AxolotlDPOTrainer",
"AxolotlEBFTTrainer",
"AxolotlKTOTrainer",
"AxolotlMambaTrainer",
"AxolotlORPOTrainer",
"AxolotlPRMTrainer",
"AxolotlRewardTrainer",
"AxolotlStridedEBFTTrainer",
]
_LAZY_IMPORTS = {
"AxolotlDPOTrainer": ".dpo.trainer",
"AxolotlStridedEBFTTrainer": ".ebft.strided",
"AxolotlEBFTTrainer": ".ebft.trainer",
"AxolotlMambaTrainer": ".mamba",
"AxolotlCPOTrainer": ".trl",
"AxolotlKTOTrainer": ".trl",
"AxolotlORPOTrainer": ".trl",
"AxolotlPRMTrainer": ".trl",
"AxolotlRewardTrainer": ".trl",
}
__getattr__ = make_lazy_getattr(_LAZY_IMPORTS, __name__, globals())

View File

@@ -9,6 +9,36 @@ import re
import torch
def make_lazy_getattr(
lazy_imports: dict[str, str], module_name: str, module_globals: dict
):
"""Create a module-level ``__getattr__`` that lazily imports symbols.
Args:
lazy_imports: Mapping of attribute name to relative module path,
e.g. ``{"AxolotlDPOTrainer": ".dpo.trainer"}``.
module_name: The ``__name__`` of the calling module (used as the
anchor for relative imports).
module_globals: The ``globals()`` dict of the calling module,
used to cache resolved attributes so ``__getattr__`` is only
invoked once per name.
Returns:
A ``__getattr__`` function suitable for assignment at module scope.
"""
import importlib
def __getattr__(name: str):
if name in lazy_imports:
module = importlib.import_module(lazy_imports[name], module_name)
attr = getattr(module, name)
module_globals[name] = attr
return attr
raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
return __getattr__
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None