diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index b067e6b64..50c06b988 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -264,8 +264,16 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
 
 
 def convert_llama_to_rrt(
-    model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"
+    model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
 ):
+    if not device:
+        if torch.backends.mps.is_available():
+            device = "mps"
+        elif torch.cuda.is_available():
+            device = "cuda"
+        else:
+            device = "cpu"
+
     modules_to_recurse = [
         "self_attn.q_proj",
         "self_attn.k_proj",
@@ -329,17 +337,10 @@ def convert_llama_to_rrt(
 
 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
-    if torch.backends.mps.is_available():
-        device = "mps"
-    elif torch.cuda.is_available():
-        device = "cuda"
-    else:
-        device = "cpu"
     convert_llama_to_rrt(
         "meta-llama/Llama-3.2-1B",
         "/tmp/rrt_model",
         recurse_layers=4,
         rank=256,
         alpha=512,
-        device=device,
     )
diff --git a/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py
new file mode 100644
index 000000000..88e743088
--- /dev/null
+++ b/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py
@@ -0,0 +1,13 @@
+from transformers import LlamaConfig
+
+
+class RelaxedRecursiveLlamaConfig(LlamaConfig):
+    """
+    Configuration for Relaxed Recursive Llama.
+    """
+
+    model_type = "llama-rrt"
+    recurse_layers: int = 4
+    rank: int
+    alpha: int
+    use_dora: bool = True
diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
index 6b00f2ec3..444dc923f 100644
--- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
+++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
@@ -12,20 +12,10 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
     LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
 
 from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
+from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
 
 logger = logging.getLogger(__name__)
 
-class RelaxedRecursiveLlamaConfig(LlamaConfig):
-    """
-    Configuration for Relaxed Recursive Llama.
-    """
-
-    model_type = "llama-rrt"
-    recurse_layers: int = 4
-    rank: int
-    alpha: int
-    use_dora: bool = True
-
 
 class RelaxedRecursiveLlamaMLP(nn.Module):
     def __init__(self, config: RelaxedRecursiveLlamaConfig):