diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index 12b436a0c..e142b5ba2 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -23,36 +23,44 @@ def import_from_path(module_name, file_path): def load(strategy, cfg, module_base=None, **kwargs): - try: - if len(strategy.split(".")) == 1: - strategy = strategy + ".default" - load_fn = strategy.split(".")[-1] - if len(strategy.split(".")) > 1: - try: - importlib.import_module( - strategy.split(".")[-2], - ".".join(strategy.split(".")[:-2]), - ) - module_base = ".".join(strategy.split(".")[:-2]) - strategy = strategy.split(".")[-2] - except ModuleNotFoundError: - try: - file_path = "/".join(strategy.split(".")[:-1]) + ".py" - module_name = strategy.split(".")[-2] - mod = import_from_path(module_name, file_path) - func = getattr(mod, load_fn) - if func is None: - strategy = "." + ".".join(strategy.split(".")[:-1]) - else: - return func(cfg, **kwargs) - except FileNotFoundError: - strategy = "." + ".".join(strategy.split(".")[:-1]) + if len(strategy.split(".")) == 1: + strategy = strategy + ".default" + load_fn = strategy.split(".")[-1] + func = None + if len(strategy.split(".")) > 1: + try: + mod = importlib.import_module( + strategy.split(".")[-2], + ".".join(strategy.split(".")[:-2]), + ) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except ModuleNotFoundError: + pass - else: - strategy = "." + ".".join(strategy.split(".")[:-1]) + try: + mod = importlib.import_module( + "." + ".".join(strategy.split(".")[:-1]), module_base + ) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except ModuleNotFoundError: + pass + + try: + file_path = "/".join(strategy.split(".")[:-1]) + ".py" + module_name = strategy.split(".")[-2] + mod = import_from_path(module_name, file_path) + func = getattr(mod, load_fn) + if func is not None: + return func(cfg, **kwargs) + except FileNotFoundError: + pass + else: + strategy = "." + ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(strategy, module_base) func = getattr(mod, load_fn) return func(cfg, **kwargs) - except Exception: # pylint: disable=broad-exception-caught - LOG.warning(f"unable to load strategy {strategy}") - return None + + LOG.warning(f"unable to load strategy {strategy}") + return func