auto modeling for rrt
This commit is contained in:
@@ -48,9 +48,9 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
def register(self): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
Registers the plugin
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
@@ -274,6 +274,7 @@ class PluginManager:
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
plugin.register()
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
|
||||
@@ -4,8 +4,12 @@ Axolotl Plugin for Relaxed Recursive Transformers
|
||||
|
||||
import logging
|
||||
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.rrt.modeling import register_rrt_model
|
||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
|
||||
RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,3 +27,16 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
|
||||
"Registering Relaxed Recursive Transformers modeling with transformers"
|
||||
)
|
||||
register_rrt_model()
|
||||
|
||||
|
||||
def register_rrt_model():
|
||||
"""
|
||||
Register Relaxed Recursive Transformers model with transformers
|
||||
"""
|
||||
|
||||
# Register configs
|
||||
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
||||
|
||||
# Register models
|
||||
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RelaxedRecursiveTransformerArgs(BaseModel):
|
||||
"""
|
||||
Arguments pertaining to the Relaxed Recursive Transformer model.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -47,7 +47,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
||||
rrt_avg_model_state_dict = {}
|
||||
|
||||
# iterate over all parameter weights in the model shards
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path):
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||
# get the matching module name in modules_to_recurse for the current parameter key
|
||||
matched_module_name = next(
|
||||
(module for module in modules_to_recurse if module in key),
|
||||
@@ -140,7 +140,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
|
||||
rrt_avg_model_state_dict = {}
|
||||
|
||||
# iterate over all parameter weights in the model shards
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path):
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||
# get the matching module name in modules_to_recurse for the current parameter key
|
||||
matched_module_name = next(
|
||||
(module for module in modules_to_recurse if module in key),
|
||||
@@ -260,9 +260,11 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
||||
# create a new state_dict to store the RRT model weights
|
||||
rrt_model_state_dict = {}
|
||||
|
||||
logger.info(f"Calculating average recursive weights...")
|
||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
|
||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||
|
||||
logger.info(f"Calculating decomposed lora diff...")
|
||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||
rrt_lora_state_dict = {}
|
||||
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
|
||||
@@ -277,4 +279,10 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
||||
|
||||
if __name__ == "__main__":
|
||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
|
||||
|
||||
@@ -20,6 +20,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
Configuration for Relaxed Recursive Llama.
|
||||
"""
|
||||
|
||||
model_type = "llama-rrt"
|
||||
recurse_layers: int = 4
|
||||
rank: int
|
||||
alpha: int
|
||||
|
||||
Reference in New Issue
Block a user