diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index afd8d01fa..3258effa7 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -36,7 +36,15 @@ def load(strategy, cfg, module_base=None, **kwargs): module_base = ".".join(strategy.split(".")[:-2]) strategy = strategy.split(".")[-2] except ModuleNotFoundError: - strategy = "." + ".".join(strategy.split(".")[:-1]) + try: + file_path = "/".join(strategy.split(".")[:-1]) + ".py" + module_name = strategy.split(".")[-2] + mod = import_from_path(module_name, file_path) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except FileNotFoundError: + strategy = "." + ".".join(strategy.split(".")[:-1]) + else: strategy = "." + ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(strategy, module_base)