diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py new file mode 100644 index 000000000..8822fc4a0 --- /dev/null +++ b/src/axolotl/integrations/rrt/__init__.py @@ -0,0 +1,25 @@ +""" +Axolotl Plugin for Relaxed Recursive Transformers +""" + +import logging + +from axolotl.integrations.base import BasePlugin +from axolotl.integrations.rrt.modeling import register_rrt_model + +LOG = logging.getLogger(__name__) + + +class RelaxedRecursiveTransformerPlugin(BasePlugin): + """ + Plugin for Relaxed Recursive Transformers integration with Axolotl + """ + + def get_input_args(self): + return "axolotl.integrations.rrt.RelaxedRecursiveTransformerArgs" + + def register(self): + LOG.info( + "Registering Relaxed Recursive Transformers modeling with transformers" + ) + register_rrt_model() diff --git a/src/axolotl/integrations/rrt/modeling/__init__.py b/src/axolotl/integrations/rrt/modeling/__init__.py new file mode 100644 index 000000000..b30629bb4 --- /dev/null +++ b/src/axolotl/integrations/rrt/modeling/__init__.py @@ -0,0 +1,2 @@ +def register_rrt_model(): + pass diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py new file mode 100644 index 000000000..e69de29bb