From 36b71f34d78a0c2b7be550f7867c6bb4920bfee5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 15 Jan 2025 23:21:22 -0500 Subject: [PATCH] register rala --- src/axolotl/integrations/base.py | 5 +++-- src/axolotl/integrations/rala/__init__.py | 5 ++--- src/axolotl/integrations/rala/convert.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 6c422c0b8..b107684c8 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -48,12 +48,12 @@ class BasePlugin: Initializes the BasePlugin. """ - def register(self, cfg): # pylint: disable=unused-argument + def register(self): # pylint: disable=unused-argument """ Registers the plugin with the given configuration. Parameters: - cfg (dict): The configuration for the plugin. + None Returns: None @@ -287,6 +287,7 @@ class PluginManager: try: plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin + plugin.register() except ImportError: logging.error(f"Failed to load plugin: {plugin_name}") diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py index 235d0cc22..ccdcc42b5 100644 --- a/src/axolotl/integrations/rala/__init__.py +++ b/src/axolotl/integrations/rala/__init__.py @@ -16,6 +16,5 @@ class RalaPlugin(BasePlugin): def get_input_args(self): return "axolotl.integrations.rala.args.RalaArgs" - def pre_model_load(self, cfg): - if cfg.rala_attention: - register_rala_model() + def register(self): + register_rala_model() diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py index da57074bf..94562f6ae 100644 --- a/src/axolotl/integrations/rala/convert.py +++ b/src/axolotl/integrations/rala/convert.py @@ -97,10 +97,10 @@ def convert_to_rala( model.config.architectures = [ "LlamaRalaForCausalLM", ] - model.config.model_type = "llama_rala" - model.config.auto_map = { - "AutoConfig": "llama.configuration_rala.LlamaRalaConfig", - "AutoModel": "llama.modeling_rala.LlamaRalaModel", - "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM", - } + model.config.model_type = "llama-rala" + # model.config.auto_map = { + # "AutoConfig": "llama.configuration_rala.LlamaRalaConfig", + # "AutoModel": "llama.modeling_rala.LlamaRalaModel", + # "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM", + # } return model