auto modeling for rrt

Wing Lian
2025-01-20 11:59:23 -05:00
parent b439ed3345
commit 82005f8eeb
5 changed files with 40 additions and 5 deletions

View File

@@ -48,9 +48,9 @@ class BasePlugin:
         Initializes the BasePlugin.
         """
 
-    def register(self, cfg):  # pylint: disable=unused-argument
+    def register(self):  # pylint: disable=unused-argument
         """
-        Registers the plugin with the given configuration.
+        Registers the plugin.
 
         Parameters:
             cfg (dict): The configuration for the plugin.
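
With `cfg` dropped from the signature, concrete plugins now do one-time setup in `register()` and receive configuration through their other hooks. A minimal sketch of a subclass under the new interface (the class name and body are illustrative, not part of this commit):

```python
from axolotl.integrations.base import BasePlugin

class MyPlugin(BasePlugin):  # hypothetical example subclass
    def register(self):
        # one-time setup, e.g. registering custom model classes with transformers
        print("MyPlugin registered")
```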
@@ -274,6 +274,7 @@ class PluginManager:
         try:
             plugin = load_plugin(plugin_name)
             self.plugins[plugin_name] = plugin
+            plugin.register()
         except ImportError:
             logging.error(f"Failed to load plugin: {plugin_name}")
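
`load_plugin` is not shown in this diff; a typical implementation resolves a dotted `module.ClassName` path with importlib. A sketch under that assumption, not the repository's actual code:

```python
import importlib

def load_plugin(plugin_name: str):
    # assumed format: "package.module.ClassName"
    module_name, _, class_name = plugin_name.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)()
```

With `plugin.register()` invoked right after loading, plugins self-register as soon as the manager instantiates them instead of relying on a separate registration pass.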

View File

@@ -4,8 +4,12 @@ Axolotl Plugin for Relaxed Recursive Transformers
 import logging
 
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
 from axolotl.integrations.base import BasePlugin
 from axolotl.integrations.rrt.modeling import register_rrt_model
+from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
+    RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
 
 LOG = logging.getLogger(__name__)
@@ -23,3 +27,16 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
             "Registering Relaxed Recursive Transformers modeling with transformers"
         )
         register_rrt_model()
+
+
+def register_rrt_model():
+    """
+    Register the Relaxed Recursive Transformers model classes with transformers.
+    """
+    # Register the config under the "llama-rrt" model_type
+    AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
+
+    # Register the model classes against the config class; AutoModel.register
+    # takes (config_class, model_class), not a model_type string
+    AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
+    AutoModelForCausalLM.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
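
Once registered, the architecture participates in the `Auto*` machinery like any built-in model. A usage sketch; the hyperparameter values are illustrative and depend on what `RelaxedRecursiveLlamaConfig.__init__` actually accepts:

```python
from transformers import AutoConfig, AutoModelForCausalLM

register_rrt_model()

# AutoConfig.for_model dispatches on the registered "llama-rrt" model_type
config = AutoConfig.for_model("llama-rrt", recurse_layers=4, rank=256, alpha=512)
model = AutoModelForCausalLM.from_config(config)
```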

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class RelaxedRecursiveTransformerArgs(BaseModel):
"""
Arguments pertaining to the Relaxed Recursive Transformer model.
"""
...
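
The args model is a stub for now. If it later mirrors the converter's knobs, a fleshed-out version might look like the sketch below; every field name and default here is an assumption, not part of the commit:

```python
from pydantic import BaseModel

class RelaxedRecursiveTransformerArgs(BaseModel):
    """Hypothetical fields mirroring convert_llama_to_rrt's parameters."""
    recurse_layers: int = 4  # assumed: layers per shared recursion block
    rank: int = 256          # assumed: low-rank dimension of the relaxation
    alpha: int = 512         # assumed: LoRA-style scaling factor
```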

View File

@@ -47,7 +47,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
     rrt_avg_model_state_dict = {}
     # iterate over all parameter weights in the model shards
-    for key, weight, layer_idx in iter_parameter_weights(model_path):
+    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
             (module for module in modules_to_recurse if module in key),
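
`iter_parameter_weights` itself is not in the diff; given the new `device` keyword, one plausible shape for it loads each safetensors shard directly onto the target device and yields `(key, tensor, layer_idx)` triples. A sketch, not the repository's code:

```python
import re
from pathlib import Path

from safetensors import safe_open

def iter_parameter_weights(model_path, device="cpu"):
    """Hypothetical helper: stream (key, tensor, layer_idx) from all model shards."""
    for shard in sorted(Path(model_path).glob("*.safetensors")):
        with safe_open(shard, framework="pt", device=device) as f:
            for key in f.keys():
                match = re.search(r"layers\.(\d+)\.", key)
                layer_idx = int(match.group(1)) if match else None
                yield key, f.get_tensor(key), layer_idx
```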
@@ -140,7 +140,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
     rrt_avg_model_state_dict = {}
     # iterate over all parameter weights in the model shards
-    for key, weight, layer_idx in iter_parameter_weights(model_path):
+    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
             (module for module in modules_to_recurse if module in key),
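
`iter_dora_parameter_weights` computes, per layer, a low-rank approximation of the gap between that layer's weight and the shared average it will be tied to. The core of that step is plausibly an SVD truncation like the sketch below (DoRA additionally tracks a per-column magnitude vector, omitted here); the names are assumptions:

```python
import torch

def lowrank_delta(weight, avg_weight, rank=256):
    """Hypothetical core step: approximate (W - W_avg) with rank-r factors."""
    delta = (weight - avg_weight).float()
    # torch.svd_lowrank: delta ≈ u @ torch.diag(s) @ v.T with u (m,r), s (r,), v (n,r)
    u, s, v = torch.svd_lowrank(delta, q=rank)
    lora_b = u * s.sqrt()        # (m, r)
    lora_a = (v * s.sqrt()).T    # (r, n), so delta ≈ lora_b @ lora_a
    return lora_a, lora_b
```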
@@ -260,9 +260,11 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}
 
+    logger.info("Calculating average recursive weights...")
     for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
 
+    logger.info("Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
     for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=alpha, rank=rank, device=device, recurse_layers=recurse_layers):
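
The averaging that `iter_recursive_parameter_weights` performs follows the Relaxed Recursive Transformer recipe: layers `i, i+K, i+2K, ...` are tied to one shared block, initialized as the mean of the weights they replace. A sketch of that grouping, assuming cyclic layer-to-block assignment (names are illustrative):

```python
import torch

def average_recursive_weights(weights_by_layer: list[torch.Tensor], recurse_layers: int):
    """Hypothetical: collapse layers i, i+K, i+2K, ... into one averaged tensor per slot."""
    num_layers = len(weights_by_layer)
    shared = {}
    for slot in range(recurse_layers):
        group = [weights_by_layer[i] for i in range(slot, num_layers, recurse_layers)]
        shared[slot] = torch.stack(group).mean(dim=0)
    return shared
```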
@@ -277,4 +279,10 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
 
 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
-    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
+    if torch.backends.mps.is_available():
+        device = "mps"
+    elif torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)

View File

@@ -20,6 +20,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     Configuration for Relaxed Recursive Llama.
     """
     model_type = "llama-rrt"
+    recurse_layers: int = 4
     rank: int
     alpha: int
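
A caveat with this pattern: `rank: int` and `alpha: int` are bare annotations, so they create no class attributes, and `PretrainedConfig` only serializes values assigned on the instance. The commit may handle this elsewhere; the conventional shape would be to set the fields in `__init__` so they round-trip through `config.json`. A sketch, not the commit's code (the defaults are assumptions):

```python
from transformers import LlamaConfig

class RelaxedRecursiveLlamaConfig(LlamaConfig):
    model_type = "llama-rrt"

    def __init__(self, recurse_layers: int = 4, rank: int = 256, alpha: int = 512, **kwargs):
        # assign on the instance so the values are written to / read from config.json
        self.recurse_layers = recurse_layers
        self.rank = rank
        self.alpha = alpha
        super().__init__(**kwargs)
```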