register rala

This commit is contained in:
Wing Lian
2025-01-15 23:21:22 -05:00
parent d28fee7609
commit 36b71f34d7
3 changed files with 11 additions and 11 deletions

View File

@@ -48,12 +48,12 @@ class BasePlugin:
Initializes the BasePlugin. Initializes the BasePlugin.
""" """
def register(self, cfg): # pylint: disable=unused-argument def register(self): # pylint: disable=unused-argument
""" """
Registers the plugin with the given configuration. Registers the plugin with the given configuration.
Parameters: Parameters:
cfg (dict): The configuration for the plugin. None
Returns: Returns:
None None
@@ -287,6 +287,7 @@ class PluginManager:
try: try:
plugin = load_plugin(plugin_name) plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin self.plugins[plugin_name] = plugin
plugin.register()
except ImportError: except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}") logging.error(f"Failed to load plugin: {plugin_name}")

View File

@@ -16,6 +16,5 @@ class RalaPlugin(BasePlugin):
def get_input_args(self): def get_input_args(self):
return "axolotl.integrations.rala.args.RalaArgs" return "axolotl.integrations.rala.args.RalaArgs"
def pre_model_load(self, cfg): def register(self):
if cfg.rala_attention: register_rala_model()
register_rala_model()

View File

@@ -97,10 +97,10 @@ def convert_to_rala(
model.config.architectures = [ model.config.architectures = [
"LlamaRalaForCausalLM", "LlamaRalaForCausalLM",
] ]
model.config.model_type = "llama_rala" model.config.model_type = "llama-rala"
model.config.auto_map = { # model.config.auto_map = {
"AutoConfig": "llama.configuration_rala.LlamaRalaConfig", # "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
"AutoModel": "llama.modeling_rala.LlamaRalaModel", # "AutoModel": "llama.modeling_rala.LlamaRalaModel",
"AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM", # "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
} # }
return model return model