update
This commit is contained in:
@@ -25,6 +25,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class LlamaDifferentialConfig(LlamaConfig):
|
class LlamaDifferentialConfig(LlamaConfig):
|
||||||
"""Configuration class for Differential LLaMA model."""
|
"""Configuration class for Differential LLaMA model."""
|
||||||
|
|
||||||
|
model_type = "llama-differential"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
split_heads: bool = False,
|
split_heads: bool = False,
|
||||||
@@ -213,6 +215,9 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||||
"""LlamaForCausalLM with differential attention."""
|
"""LlamaForCausalLM with differential attention."""
|
||||||
|
|
||||||
|
config_class = LlamaDifferentialConfig
|
||||||
|
base_model_prefix = "llama_differential"
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = LlamaDifferentialModel(config)
|
self.model = LlamaDifferentialModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user