This commit is contained in:
Dan Saunders
2024-12-27 21:29:37 +00:00
parent 78e0ec0aa5
commit e5fa842ff8

View File

@@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model."""
model_type = "llama-differential"
def __init__(
self,
split_heads: bool = False,
@@ -213,6 +215,9 @@ class LlamaDifferentialModel(LlamaModel):
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""LlamaForCausalLM with differential attention."""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config):
super().__init__(config)
self.model = LlamaDifferentialModel(config)