From 257231ac46d0cf0c7c0f977591b5f54ded9c59b9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 17 Jan 2025 08:48:45 -0500 Subject: [PATCH] wip rrt --- src/axolotl/integrations/rrt/__init__.py | 25 +++++++++++++++++++ .../integrations/rrt/modeling/__init__.py | 2 ++ .../rrt/modeling/modeling_rrt_llama.py | 0 3 files changed, 27 insertions(+) create mode 100644 src/axolotl/integrations/rrt/__init__.py create mode 100644 src/axolotl/integrations/rrt/modeling/__init__.py create mode 100644 src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py new file mode 100644 index 000000000..8822fc4a0 --- /dev/null +++ b/src/axolotl/integrations/rrt/__init__.py @@ -0,0 +1,25 @@ +""" +Axolotl Plugin for Relaxed Recursive Transformers +""" + +import logging + +from axolotl.integrations.base import BasePlugin +from axolotl.integrations.rrt.modeling import register_rrt_model + +LOG = logging.getLogger(__name__) + + +class RelaxedRecursiveTransformerPlugin(BasePlugin): + """ + Plugin for Relaxed Recursive Transformers integration with Axolotl + """ + + def get_input_args(self): + return "axolotl.integrations.rrt.RelaxedRecursiveTransformerArgs" + + def register(self): + LOG.info( + "Registering Relaxed Recursive Transformers modeling with transformers" + ) + register_rrt_model() diff --git a/src/axolotl/integrations/rrt/modeling/__init__.py b/src/axolotl/integrations/rrt/modeling/__init__.py new file mode 100644 index 000000000..b30629bb4 --- /dev/null +++ b/src/axolotl/integrations/rrt/modeling/__init__.py @@ -0,0 +1,2 @@ +def register_rrt_model(): + pass diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py new file mode 100644 index 000000000..e69de29bb