adding registration function
This commit is contained in:
committed by
Dan Saunders
parent
eb6611d55f
commit
3bc568eb27
@@ -3,6 +3,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
@@ -172,3 +173,12 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
|||||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|
||||||
|
def register_diff_attn():
|
||||||
|
# Register configs
|
||||||
|
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
|
||||||
|
|
||||||
|
# Register models
|
||||||
|
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
|
||||||
|
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
|
||||||
|
|||||||
Reference in New Issue
Block a user