diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index b84dfcd16..cf75ff37a 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, @@ -172,3 +173,12 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM): new_model.lm_head.load_state_dict(model.lm_head.state_dict()) return new_model + + +def register_diff_attn(): + # Register configs + AutoConfig.register("llama-differential", LlamaDifferentialConfig) + + # Register models + AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel) + AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)