From ec52561a0c15ba6d3b6dae19b965516310a9c399 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Sat, 15 Mar 2025 20:25:53 -0400 Subject: [PATCH] import from filepath if can't import_module --- src/axolotl/prompt_strategies/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index afd8d01fa..3258effa7 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -36,7 +36,15 @@ def load(strategy, cfg, module_base=None, **kwargs): module_base = ".".join(strategy.split(".")[:-2]) strategy = strategy.split(".")[-2] except ModuleNotFoundError: - strategy = "." + ".".join(strategy.split(".")[:-1]) + try: + file_path = "/".join(strategy.split(".")[:-1]) + ".py" + module_name = strategy.split(".")[-2] + mod = import_from_path(module_name, file_path) + func = getattr(mod, load_fn) + return func(cfg, **kwargs) + except FileNotFoundError: + strategy = "." + ".".join(strategy.split(".")[:-1]) + else: strategy = "." + ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(strategy, module_base)