update
This commit is contained in:
@@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
||||
class LlamaDifferentialConfig(LlamaConfig):
|
||||
"""Configuration class for Differential LLaMA model."""
|
||||
|
||||
model_type = "llama-differential"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_heads: bool = False,
|
||||
@@ -213,6 +215,9 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
"""LlamaForCausalLM with differential attention."""
|
||||
|
||||
config_class = LlamaDifferentialConfig
|
||||
base_model_prefix = "llama_differential"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LlamaDifferentialModel(config)
|
||||
|
||||
Reference in New Issue
Block a user