adding registration function

Dan Saunders
2024-12-27 11:17:52 -05:00
committed by Dan Saunders
parent eb6611d55f
commit 3bc568eb27

@@ -3,6 +3,7 @@
 from typing import Optional
 import torch
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
@@ -172,3 +173,12 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
         new_model.lm_head.load_state_dict(model.lm_head.state_dict())
         return new_model
+
+
+def register_diff_attn():
+    # Register configs
+    AutoConfig.register("llama-differential", LlamaDifferentialConfig)
+
+    # Register models
+    AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
+    AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
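
For context, a minimal usage sketch (not part of the commit): once register_diff_attn() has run, the transformers Auto classes resolve the "llama-differential" model type to the custom config and model classes. The `llama_differential` import path below is a placeholder for wherever this module lives in the package.

    # Minimal sketch; `llama_differential` is a hypothetical import path
    # for the module shown in the diff above.
    from llama_differential import register_diff_attn
    from transformers import AutoConfig, AutoModelForCausalLM

    register_diff_attn()

    # After registration, the Auto classes map the "llama-differential"
    # model_type to the registered classes.
    config = AutoConfig.for_model("llama-differential")   # -> LlamaDifferentialConfig
    model = AutoModelForCausalLM.from_config(config)      # -> LlamaDifferentialForCausalLM

The same mapping also lets AutoModelForCausalLM.from_pretrained() load a saved checkpoint whose config.json declares "model_type": "llama-differential".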