fix: config to allow optional input

NanoCode012
2025-02-05 15:52:30 +07:00
parent 2bc7833a4e
commit 4cc60df876


@@ -14,6 +14,8 @@
 # limitations under the License.
 """Linear LLaMA model configuration"""
+from typing import Optional
+
 from transformers import LlamaConfig
@@ -26,6 +28,8 @@ class LinearLlamaConfig(LlamaConfig):
         attention_config (`dict`):
             Dictionary containing the configuration for linear attention mechanism.
             Expected contents:
+                `attention_type` (`str`):
+                    The type of attention to convert to.
                 `feature_map` (`str`):
                     The type of feature map to use for linear attention.
                 `feature_map_kwargs` (`dict`):
@@ -57,11 +61,11 @@ class LinearLlamaConfig(LlamaConfig):
     model_type = "linear_llama"
 
-    def __init__(self, attention_config: dict, **kwargs):
+    def __init__(self, attention_config: Optional[dict] = None, **kwargs):
         super().__init__(**kwargs)
-        self.attention_config = attention_config
+        # Set default attention config if none provided
+        self.attention_config = attention_config or {"attention_type": "softmax"}
 
     @classmethod
     def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
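
With this change, `LinearLlamaConfig` no longer requires an `attention_config` at construction time. A minimal usage sketch of the new behavior (the import path is an assumption; adjust it to wherever the class lives in this repo, and the "linear"/"feature_map" values below are illustrative, not canonical):

    # Sketch only: "configuration_linear_llama" is an assumed module path.
    from configuration_linear_llama import LinearLlamaConfig

    # Omitting attention_config now falls back to the softmax default
    # instead of raising a TypeError for the missing argument.
    config = LinearLlamaConfig()
    assert config.attention_config == {"attention_type": "softmax"}

    # An explicit attention_config still takes precedence over the default.
    config = LinearLlamaConfig(
        attention_config={"attention_type": "linear", "feature_map": "elu"}
    )
    assert config.attention_config["attention_type"] == "linear"

Note that the `or` fallback also replaces a falsy value such as an explicitly passed `{}` with the default; a `if attention_config is None` check would preserve an explicit empty dict.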