From 3bc568eb27e0f67706afc0a866b9b070c4f15d19 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 27 Dec 2024 11:17:52 -0500 Subject: [PATCH] adding registration function --- .../diff_transformer/modeling_diff_attn.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index b84dfcd16..cf75ff37a 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, @@ -172,3 +173,12 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM): new_model.lm_head.load_state_dict(model.lm_head.state_dict()) return new_model + + +def register_diff_attn(): + # Register configs + AutoConfig.register("llama-differential", LlamaDifferentialConfig) + + # Register models + AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel) + AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)