set config on the PluginManager for callback access (#2587)

This commit is contained in:
Wing Lian
2025-04-29 12:05:44 -04:00
committed by GitHub
parent 80b4edb4a7
commit 6565ae85d8
2 changed files with 23 additions and 7 deletions

View File

@@ -152,6 +152,12 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager.register(plugin_name) plugin_manager.register(plugin_name)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault: def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
""" """
Loads the `axolotl` configuration stored at `config`, validates it, and performs Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -213,5 +219,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg) setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg) setup_comet_env_vars(cfg)
plugin_set_cfg(cfg)
return cfg return cfg

View File

@@ -270,6 +270,7 @@ class PluginManager:
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None _instance = None
_cfg = None
def __new__(cls): def __new__(cls):
""" """
@@ -277,7 +278,9 @@ class PluginManager:
""" """
if cls._instance is None: if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls) cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict() cls._instance.plugins: OrderedDict[str, BasePlugin] = (
collections.OrderedDict()
)
return cls._instance return cls._instance
@staticmethod @staticmethod
@@ -290,6 +293,14 @@ class PluginManager:
PluginManager() PluginManager()
return PluginManager._instance # type: ignore return PluginManager._instance # type: ignore
@property
def cfg(self):
return self._cfg
@cfg.setter
def cfg(self, cfg):
self._cfg = cfg
def register(self, plugin_name: str): def register(self, plugin_name: str):
""" """
Registers a new plugin by its name. Registers a new plugin by its name.
@@ -409,29 +420,27 @@ class PluginManager:
return trainer_cls return trainer_cls
return None return None
def create_optimizer(self, cfg, trainer): def create_optimizer(self, trainer):
""" """
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters: Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
Returns: Returns:
object: The created optimizer, or None if none was found. object: The created optimizer, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer) optimizer = plugin.create_optimizer(self.cfg, trainer)
if optimizer is not None: if optimizer is not None:
return optimizer return optimizer
return None return None
def create_lr_scheduler(self, cfg, trainer, optimizer): def create_lr_scheduler(self, trainer, optimizer):
""" """
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters: Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training. trainer (object): The trainer object for training.
optimizer (object): The optimizer for training. optimizer (object): The optimizer for training.
@@ -439,7 +448,7 @@ class PluginManager:
object: The created learning rate scheduler, or None if none was found. object: The created learning rate scheduler, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
if scheduler is not None: if scheduler is not None:
return scheduler return scheduler
return None return None