diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index 166a67670..64bf402b9 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -152,6 +152,12 @@ def prepare_plugins(cfg: DictDefault):
         plugin_manager.register(plugin_name)
 
 
+def plugin_set_cfg(cfg: DictDefault):
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        plugin_manager.cfg = cfg
+
+
 def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
@@ -213,5 +219,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     setup_wandb_env_vars(cfg)
     setup_mlflow_env_vars(cfg)
     setup_comet_env_vars(cfg)
+    plugin_set_cfg(cfg)
 
     return cfg
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index cb65f96dd..7d6491478 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -270,6 +270,7 @@ class PluginManager:
     plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
 
     _instance = None
+    _cfg = None
 
     def __new__(cls):
         """
@@ -277,7 +278,9 @@ class PluginManager:
         """
         if cls._instance is None:
            cls._instance = super(PluginManager, cls).__new__(cls)
-            cls._instance.plugins = collections.OrderedDict()
+            cls._instance.plugins: OrderedDict[str, BasePlugin] = (
+                collections.OrderedDict()
+            )
         return cls._instance
 
     @staticmethod
@@ -290,6 +293,14 @@ class PluginManager:
             PluginManager()
         return PluginManager._instance  # type: ignore
 
+    @property
+    def cfg(self):
+        return self._cfg
+
+    @cfg.setter
+    def cfg(self, cfg):
+        self._cfg = cfg
+
     def register(self, plugin_name: str):
         """
         Registers a new plugin by its name.
@@ -409,29 +420,27 @@ class PluginManager:
             return trainer_cls
         return None
 
-    def create_optimizer(self, cfg, trainer):
+    def create_optimizer(self, trainer):
         """
         Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
 
         Parameters:
-        cfg (dict): The configuration for the plugins.
         trainer (object): The trainer object for training.
 
         Returns:
         object: The created optimizer, or None if none was found.
         """
         for plugin in self.plugins.values():
-            optimizer = plugin.create_optimizer(cfg, trainer)
+            optimizer = plugin.create_optimizer(self.cfg, trainer)
             if optimizer is not None:
                 return optimizer
         return None
 
-    def create_lr_scheduler(self, cfg, trainer, optimizer):
+    def create_lr_scheduler(self, trainer, optimizer):
         """
         Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
 
         Parameters:
-        cfg (dict): The configuration for the plugins.
         trainer (object): The trainer object for training.
         optimizer (object): The optimizer for training.
 
@@ -439,7 +448,7 @@ class PluginManager:
         object: The created learning rate scheduler, or None if none was found.
         """
         for plugin in self.plugins.values():
-            scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
+            scheduler = plugin.create_lr_scheduler(self.cfg, trainer, optimizer)
             if scheduler is not None:
                 return scheduler
         return None