register rala
This commit is contained in:
@@ -48,12 +48,12 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
def register(self): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -287,6 +287,7 @@ class PluginManager:
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
plugin.register()
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
|
||||
@@ -16,6 +16,5 @@ class RalaPlugin(BasePlugin):
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.rala.args.RalaArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.rala_attention:
|
||||
register_rala_model()
|
||||
def register(self):
|
||||
register_rala_model()
|
||||
|
||||
@@ -97,10 +97,10 @@ def convert_to_rala(
|
||||
model.config.architectures = [
|
||||
"LlamaRalaForCausalLM",
|
||||
]
|
||||
model.config.model_type = "llama_rala"
|
||||
model.config.auto_map = {
|
||||
"AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
|
||||
"AutoModel": "llama.modeling_rala.LlamaRalaModel",
|
||||
"AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
|
||||
}
|
||||
model.config.model_type = "llama-rala"
|
||||
# model.config.auto_map = {
|
||||
# "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
|
||||
# "AutoModel": "llama.modeling_rala.LlamaRalaModel",
|
||||
# "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
|
||||
# }
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user