lazy load trainer classes to prevent unnecesary imports (#3568)
* lazy load trainer classes to prevent unnecesary imports * make the lazy load a common util
This commit is contained in:
@@ -2,15 +2,34 @@
|
||||
|
||||
# flake8: noqa
|
||||
|
||||
from axolotl.utils import make_lazy_getattr
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .ebft.strided import AxolotlStridedEBFTTrainer
|
||||
from .ebft.trainer import AxolotlEBFTTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
__all__ = [
|
||||
"AxolotlTrainer",
|
||||
"AxolotlCPOTrainer",
|
||||
"AxolotlDPOTrainer",
|
||||
"AxolotlEBFTTrainer",
|
||||
"AxolotlKTOTrainer",
|
||||
"AxolotlMambaTrainer",
|
||||
"AxolotlORPOTrainer",
|
||||
"AxolotlPRMTrainer",
|
||||
"AxolotlRewardTrainer",
|
||||
"AxolotlStridedEBFTTrainer",
|
||||
]
|
||||
|
||||
_LAZY_IMPORTS = {
|
||||
"AxolotlDPOTrainer": ".dpo.trainer",
|
||||
"AxolotlStridedEBFTTrainer": ".ebft.strided",
|
||||
"AxolotlEBFTTrainer": ".ebft.trainer",
|
||||
"AxolotlMambaTrainer": ".mamba",
|
||||
"AxolotlCPOTrainer": ".trl",
|
||||
"AxolotlKTOTrainer": ".trl",
|
||||
"AxolotlORPOTrainer": ".trl",
|
||||
"AxolotlPRMTrainer": ".trl",
|
||||
"AxolotlRewardTrainer": ".trl",
|
||||
}
|
||||
|
||||
__getattr__ = make_lazy_getattr(_LAZY_IMPORTS, __name__, globals())
|
||||
|
||||
@@ -9,6 +9,36 @@ import re
|
||||
import torch
|
||||
|
||||
|
||||
def make_lazy_getattr(
|
||||
lazy_imports: dict[str, str], module_name: str, module_globals: dict
|
||||
):
|
||||
"""Create a module-level ``__getattr__`` that lazily imports symbols.
|
||||
|
||||
Args:
|
||||
lazy_imports: Mapping of attribute name to relative module path,
|
||||
e.g. ``{"AxolotlDPOTrainer": ".dpo.trainer"}``.
|
||||
module_name: The ``__name__`` of the calling module (used as the
|
||||
anchor for relative imports).
|
||||
module_globals: The ``globals()`` dict of the calling module,
|
||||
used to cache resolved attributes so ``__getattr__`` is only
|
||||
invoked once per name.
|
||||
|
||||
Returns:
|
||||
A ``__getattr__`` function suitable for assignment at module scope.
|
||||
"""
|
||||
import importlib
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in lazy_imports:
|
||||
module = importlib.import_module(lazy_imports[name], module_name)
|
||||
attr = getattr(module, name)
|
||||
module_globals[name] = attr
|
||||
return attr
|
||||
raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
|
||||
|
||||
return __getattr__
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user