diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index 3258effa7..12b436a0c 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -41,7 +41,10 @@ def load(strategy, cfg, module_base=None, **kwargs): module_name = strategy.split(".")[-2] mod = import_from_path(module_name, file_path) func = getattr(mod, load_fn) - return func(cfg, **kwargs) + if func is None: + strategy = "." + ".".join(strategy.split(".")[:-1]) + else: + return func(cfg, **kwargs) except FileNotFoundError: strategy = "." + ".".join(strategy.split(".")[:-1])