From 79159b4871c6aae6467cdeb183fe88f4c2d9a0bf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 2 Feb 2025 23:47:00 -0500 Subject: [PATCH] support custom module prompt strategy for rl --- src/axolotl/prompt_strategies/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index cddb3d0e1..ae5f85450 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -14,6 +14,16 @@ def load(strategy, cfg, module_base=None, **kwargs): strategy = strategy + ".default" load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) + if len(strategy.split(".")) > 1: + try: + importlib.import_module( + "." + strategy.split(".")[-1], + ".".join(strategy.split(".")[:-1]), + ) + module_base = ".".join(strategy.split(".")[:-1]) + strategy = strategy.split(".")[-1] + except ModuleNotFoundError: + pass mod = importlib.import_module(f".{strategy}", module_base) func = getattr(mod, load_fn) return func(cfg, **kwargs)