auto modeling for rrt
This commit is contained in:
@@ -48,9 +48,9 @@ class BasePlugin:
|
|||||||
Initializes the BasePlugin.
|
Initializes the BasePlugin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def register(self, cfg): # pylint: disable=unused-argument
|
def register(self): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Registers the plugin with the given configuration.
|
Registers the plugin
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
@@ -274,6 +274,7 @@ class PluginManager:
|
|||||||
try:
|
try:
|
||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins[plugin_name] = plugin
|
||||||
|
plugin.register()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ Axolotl Plugin for Relaxed Recursive Transformers
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.integrations.rrt.modeling import register_rrt_model
|
from axolotl.integrations.rrt.modeling import register_rrt_model
|
||||||
|
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
|
||||||
|
RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,3 +27,16 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
|
|||||||
"Registering Relaxed Recursive Transformers modeling with transformers"
|
"Registering Relaxed Recursive Transformers modeling with transformers"
|
||||||
)
|
)
|
||||||
register_rrt_model()
|
register_rrt_model()
|
||||||
|
|
||||||
|
|
||||||
|
def register_rrt_model():
|
||||||
|
"""
|
||||||
|
Register Relaxed Recursive Transformers model with transformers
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Register configs
|
||||||
|
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
||||||
|
|
||||||
|
# Register models
|
||||||
|
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||||
|
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RelaxedRecursiveTransformerArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Arguments pertaining to the Relaxed Recursive Transformer model.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
|||||||
rrt_avg_model_state_dict = {}
|
rrt_avg_model_state_dict = {}
|
||||||
|
|
||||||
# iterate over all parameter weights in the model shards
|
# iterate over all parameter weights in the model shards
|
||||||
for key, weight, layer_idx in iter_parameter_weights(model_path):
|
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||||
# get the matching module name in modules_to_recurse for the current parameter key
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
matched_module_name = next(
|
matched_module_name = next(
|
||||||
(module for module in modules_to_recurse if module in key),
|
(module for module in modules_to_recurse if module in key),
|
||||||
@@ -140,7 +140,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
|
|||||||
rrt_avg_model_state_dict = {}
|
rrt_avg_model_state_dict = {}
|
||||||
|
|
||||||
# iterate over all parameter weights in the model shards
|
# iterate over all parameter weights in the model shards
|
||||||
for key, weight, layer_idx in iter_parameter_weights(model_path):
|
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||||
# get the matching module name in modules_to_recurse for the current parameter key
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
matched_module_name = next(
|
matched_module_name = next(
|
||||||
(module for module in modules_to_recurse if module in key),
|
(module for module in modules_to_recurse if module in key),
|
||||||
@@ -260,9 +260,11 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
|||||||
# create a new state_dict to store the RRT model weights
|
# create a new state_dict to store the RRT model weights
|
||||||
rrt_model_state_dict = {}
|
rrt_model_state_dict = {}
|
||||||
|
|
||||||
|
logger.info(f"Calculating average recursive weights...")
|
||||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
|
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
|
||||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
|
logger.info(f"Calculating decomposed lora diff...")
|
||||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||||
rrt_lora_state_dict = {}
|
rrt_lora_state_dict = {}
|
||||||
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
|
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
|
||||||
@@ -277,4 +279,10 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
|
if torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
elif torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
|||||||
Configuration for Relaxed Recursive Llama.
|
Configuration for Relaxed Recursive Llama.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_type = "llama-rrt"
|
||||||
recurse_layers: int = 4
|
recurse_layers: int = 4
|
||||||
rank: int
|
rank: int
|
||||||
alpha: int
|
alpha: int
|
||||||
|
|||||||
Reference in New Issue
Block a user