diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index a271c59d1..0f1928fe5 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -48,9 +48,6 @@ class BasePlugin:
         Initializes the BasePlugin.
         """
 
-    def register(self, cfg):  # pylint: disable=unused-argument
+    def register(self):
         """
-        Registers the plugin with the given configuration.
-
-        Parameters:
-        cfg (dict): The configuration for the plugin.
+        Registers the plugin.
@@ -274,6 +274,7 @@ class PluginManager:
         try:
             plugin = load_plugin(plugin_name)
             self.plugins[plugin_name] = plugin
+            plugin.register()
         except ImportError:
             logging.error(f"Failed to load plugin: {plugin_name}")
diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py
index 8822fc4a0..75681306a 100644
--- a/src/axolotl/integrations/rrt/__init__.py
+++ b/src/axolotl/integrations/rrt/__init__.py
@@ -4,8 +4,11 @@ Axolotl Plugin for Relaxed Recursive Transformers
 
 import logging
 
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
 from axolotl.integrations.base import BasePlugin
-from axolotl.integrations.rrt.modeling import register_rrt_model
+from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
+    RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
 
 LOG = logging.getLogger(__name__)
@@ -23,3 +26,16 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
             "Registering Relaxed Recursive Transformers modeling with transformers"
         )
         register_rrt_model()
+
+
+def register_rrt_model():
+    """
+    Register Relaxed Recursive Transformers model with transformers
+    """
+
+    # Register configs
+    AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
+
+    # Register models
+    AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
+    AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)
diff --git a/src/axolotl/integrations/rrt/args.py b/src/axolotl/integrations/rrt/args.py
index e69de29bb..18cf0d360 100644
--- a/src/axolotl/integrations/rrt/args.py
+++ b/src/axolotl/integrations/rrt/args.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class RelaxedRecursiveTransformerArgs(BaseModel):
+    """
+    Arguments pertaining to the Relaxed Recursive Transformer model.
+    """
+    ...
diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index 38d994c93..8d90db484 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -47,7 +47,7 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
     rrt_avg_model_state_dict = {}
 
     # iterate over all parameter weights in the model shards
-    for key, weight, layer_idx in iter_parameter_weights(model_path):
+    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
             (module for module in modules_to_recurse if module in key),
@@ -140,7 +140,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
     rrt_avg_model_state_dict = {}
 
     # iterate over all parameter weights in the model shards
-    for key, weight, layer_idx in iter_parameter_weights(model_path):
+    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
         matched_module_name = next(
             (module for module in modules_to_recurse if module in key),
@@ -260,9 +260,11 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}
+    logger.info("Calculating average recursive weights...")
     for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device,
                                                         recurse_layers=recurse_layers):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
+    logger.info("Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
-    for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
+    for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=alpha, rank=rank, device=device, recurse_layers=recurse_layers):
@@ -277,4 +279,10 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
 
 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
-    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
+    if torch.backends.mps.is_available():
+        device = "mps"
+    elif torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
index 4b64c7bf4..6b00f2ec3 100644
--- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
+++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
@@ -20,6 +20,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     Configuration for Relaxed Recursive Llama.
     """
 
+    model_type = "llama-rrt"
     recurse_layers: int = 4
     rank: int
     alpha: int