register rala
This commit is contained in:
@@ -48,12 +48,12 @@ class BasePlugin:
|
|||||||
Initializes the BasePlugin.
|
Initializes the BasePlugin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register(self, cfg): # pylint: disable=unused-argument
|
def register(self): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Registers the plugin with the given configuration.
|
Registers the plugin with the given configuration.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
@@ -287,6 +287,7 @@ class PluginManager:
|
|||||||
try:
|
try:
|
||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins[plugin_name] = plugin
|
||||||
|
plugin.register()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,5 @@ class RalaPlugin(BasePlugin):
|
|||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.rala.args.RalaArgs"
|
return "axolotl.integrations.rala.args.RalaArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def register(self):
|
||||||
if cfg.rala_attention:
|
register_rala_model()
|
||||||
register_rala_model()
|
|
||||||
|
|||||||
@@ -97,10 +97,10 @@ def convert_to_rala(
|
|||||||
model.config.architectures = [
|
model.config.architectures = [
|
||||||
"LlamaRalaForCausalLM",
|
"LlamaRalaForCausalLM",
|
||||||
]
|
]
|
||||||
model.config.model_type = "llama_rala"
|
model.config.model_type = "llama-rala"
|
||||||
model.config.auto_map = {
|
# model.config.auto_map = {
|
||||||
"AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
|
# "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
|
||||||
"AutoModel": "llama.modeling_rala.LlamaRalaModel",
|
# "AutoModel": "llama.modeling_rala.LlamaRalaModel",
|
||||||
"AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
|
# "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
|
||||||
}
|
# }
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user