adding registration function
This commit is contained in:
committed by
Dan Saunders
parent
eb6611d55f
commit
3bc568eb27
@@ -3,6 +3,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaForCausalLM,
|
||||
@@ -172,3 +173,12 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
||||
|
||||
return new_model
|
||||
|
||||
|
||||
def register_diff_attn():
|
||||
# Register configs
|
||||
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
|
||||
|
||||
# Register models
|
||||
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
|
||||
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
|
||||
|
||||
Reference in New Issue
Block a user