lazy load trainer classes to prevent unnecesary imports (#3568)
* lazy load trainer classes to prevent unnecesary imports * make the lazy load a common util
This commit is contained in:
@@ -2,15 +2,34 @@
|
|||||||
|
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
|
from axolotl.utils import make_lazy_getattr
|
||||||
|
|
||||||
from .base import AxolotlTrainer
|
from .base import AxolotlTrainer
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
|
||||||
from .ebft.strided import AxolotlStridedEBFTTrainer
|
# noinspection PyUnresolvedReferences
|
||||||
from .ebft.trainer import AxolotlEBFTTrainer
|
__all__ = [
|
||||||
from .mamba import AxolotlMambaTrainer
|
"AxolotlTrainer",
|
||||||
from .trl import (
|
"AxolotlCPOTrainer",
|
||||||
AxolotlCPOTrainer,
|
"AxolotlDPOTrainer",
|
||||||
AxolotlKTOTrainer,
|
"AxolotlEBFTTrainer",
|
||||||
AxolotlORPOTrainer,
|
"AxolotlKTOTrainer",
|
||||||
AxolotlPRMTrainer,
|
"AxolotlMambaTrainer",
|
||||||
AxolotlRewardTrainer,
|
"AxolotlORPOTrainer",
|
||||||
)
|
"AxolotlPRMTrainer",
|
||||||
|
"AxolotlRewardTrainer",
|
||||||
|
"AxolotlStridedEBFTTrainer",
|
||||||
|
]
|
||||||
|
|
||||||
|
_LAZY_IMPORTS = {
|
||||||
|
"AxolotlDPOTrainer": ".dpo.trainer",
|
||||||
|
"AxolotlStridedEBFTTrainer": ".ebft.strided",
|
||||||
|
"AxolotlEBFTTrainer": ".ebft.trainer",
|
||||||
|
"AxolotlMambaTrainer": ".mamba",
|
||||||
|
"AxolotlCPOTrainer": ".trl",
|
||||||
|
"AxolotlKTOTrainer": ".trl",
|
||||||
|
"AxolotlORPOTrainer": ".trl",
|
||||||
|
"AxolotlPRMTrainer": ".trl",
|
||||||
|
"AxolotlRewardTrainer": ".trl",
|
||||||
|
}
|
||||||
|
|
||||||
|
__getattr__ = make_lazy_getattr(_LAZY_IMPORTS, __name__, globals())
|
||||||
|
|||||||
@@ -9,6 +9,36 @@ import re
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_lazy_getattr(
|
||||||
|
lazy_imports: dict[str, str], module_name: str, module_globals: dict
|
||||||
|
):
|
||||||
|
"""Create a module-level ``__getattr__`` that lazily imports symbols.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lazy_imports: Mapping of attribute name to relative module path,
|
||||||
|
e.g. ``{"AxolotlDPOTrainer": ".dpo.trainer"}``.
|
||||||
|
module_name: The ``__name__`` of the calling module (used as the
|
||||||
|
anchor for relative imports).
|
||||||
|
module_globals: The ``globals()`` dict of the calling module,
|
||||||
|
used to cache resolved attributes so ``__getattr__`` is only
|
||||||
|
invoked once per name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``__getattr__`` function suitable for assignment at module scope.
|
||||||
|
"""
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name in lazy_imports:
|
||||||
|
module = importlib.import_module(lazy_imports[name], module_name)
|
||||||
|
attr = getattr(module, name)
|
||||||
|
module_globals[name] = attr
|
||||||
|
return attr
|
||||||
|
raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
return __getattr__
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
def is_mlflow_available():
|
||||||
return importlib.util.find_spec("mlflow") is not None
|
return importlib.util.find_spec("mlflow") is not None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user