fix: config to allow optional input
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Linear LLaMA model configuration"""
 
+from typing import Optional
+
 from transformers import LlamaConfig
 
 
@@ -26,6 +28,8 @@ class LinearLlamaConfig(LlamaConfig):
         attention_config (`dict`):
             Dictionary containing the configuration for the linear attention mechanism.
             Expected contents:
+                `attention_type` (`str`):
+                    The type of attention to convert to.
                 `feature_map` (`str`):
                     The type of feature map to use for linear attention.
                 `feature_map_kwargs` (`dict`):
@@ -57,11 +61,11 @@ class LinearLlamaConfig(LlamaConfig):
 
     model_type = "linear_llama"
 
-    def __init__(self, attention_config: dict, **kwargs):
+    def __init__(self, attention_config: Optional[dict] = None, **kwargs):
         super().__init__(**kwargs)
-
-        self.attention_config = attention_config
+        # Set default attention config if none provided
+        self.attention_config = attention_config or {"attention_type": "softmax"}
 
     @classmethod
     def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
 
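The net effect: `attention_config` no longer has to be supplied and falls back to standard softmax attention. A minimal usage sketch follows; the import path and the "linear"/"elu" values are illustrative assumptions, while the default {"attention_type": "softmax"} and both signatures come from the diff above:

    from transformers import LlamaConfig
    # Hypothetical import path; adjust to the actual module layout.
    from configuration_linear_llama import LinearLlamaConfig

    # Before this fix, attention_config was a required argument.
    # Now the config can be built with no arguments and defaults to softmax.
    config = LinearLlamaConfig()
    assert config.attention_config == {"attention_type": "softmax"}

    # An explicitly passed attention_config is stored as-is.
    linear_config = LinearLlamaConfig(
        attention_config={"attention_type": "linear", "feature_map": "elu"}
    )

    # from_llama keeps attention_config required (signature shown above);
    # the return value is assumed to be the converted config.
    converted = LinearLlamaConfig.from_llama(
        LlamaConfig(), {"attention_type": "linear"}
    )

One consequence of using `or` for the fallback: an empty dict is falsy, so an explicitly passed `attention_config={}` is also replaced by the softmax default. If an empty-but-explicit config should survive, an `attention_config is None` check would be the stricter alternative.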