register rala

This commit is contained in:
Wing Lian
2025-01-15 23:21:22 -05:00
parent d28fee7609
commit 36b71f34d7
3 changed files with 11 additions and 11 deletions

View File

@@ -48,12 +48,12 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
None
Returns:
None
@@ -287,6 +287,7 @@ class PluginManager:
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
plugin.register()
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")

View File

@@ -16,6 +16,5 @@ class RalaPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.rala.args.RalaArgs"
def pre_model_load(self, cfg):
if cfg.rala_attention:
register_rala_model()
def register(self):
register_rala_model()

View File

@@ -97,10 +97,10 @@ def convert_to_rala(
model.config.architectures = [
"LlamaRalaForCausalLM",
]
model.config.model_type = "llama_rala"
model.config.auto_map = {
"AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
"AutoModel": "llama.modeling_rala.LlamaRalaModel",
"AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
}
model.config.model_type = "llama-rala"
# model.config.auto_map = {
# "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
# "AutoModel": "llama.modeling_rala.LlamaRalaModel",
# "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
# }
return model