fix: config to allow optional input
This commit is contained in:
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Linear LLaMA model configuration"""
|
"""Linear LLaMA model configuration"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -26,6 +28,8 @@ class LinearLlamaConfig(LlamaConfig):
|
|||||||
attention_config (`dict`):
|
attention_config (`dict`):
|
||||||
Dictionary containing the configuration for linear attention mechanism.
|
Dictionary containing the configuration for linear attention mechanism.
|
||||||
Expected contents:
|
Expected contents:
|
||||||
|
`attention_type` (str):
|
||||||
|
The type of attention to convert to.
|
||||||
`feature_map` (`str`):
|
`feature_map` (`str`):
|
||||||
The type of feature map to use for linear attention.
|
The type of feature map to use for linear attention.
|
||||||
`feature_map_kwargs` (`dict`):
|
`feature_map_kwargs` (`dict`):
|
||||||
@@ -57,11 +61,11 @@ class LinearLlamaConfig(LlamaConfig):
|
|||||||
|
|
||||||
model_type = "linear_llama"
|
model_type = "linear_llama"
|
||||||
|
|
||||||
def __init__(self, attention_config: dict, **kwargs):
|
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
# Set default attention config if none provided
|
# Set default attention config if none provided
|
||||||
self.attention_config = attention_config
|
self.attention_config = attention_config or {"attention_type": "softmax"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
||||||
|
|||||||
Reference in New Issue
Block a user