diff --git a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py index 2eb0ec8a4..3997c134b 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py @@ -14,6 +14,8 @@ # limitations under the License. """Linear LLaMA model configuration""" +from typing import Optional + from transformers import LlamaConfig @@ -26,6 +28,8 @@ class LinearLlamaConfig(LlamaConfig): attention_config (`dict`): Dictionary containing the configuration for linear attention mechanism. Expected contents: + `attention_type` (`str`): + The type of attention to convert to. `feature_map` (`str`): The type of feature map to use for linear attention. `feature_map_kwargs` (`dict`): @@ -57,11 +61,11 @@ class LinearLlamaConfig(LlamaConfig): model_type = "linear_llama" - def __init__(self, attention_config: Optional[dict] = None, **kwargs): + def __init__(self, attention_config: Optional[dict] = None, **kwargs): super().__init__(**kwargs) # Set default attention config if none provided - self.attention_config = attention_config + self.attention_config = attention_config or {"attention_type": "softmax"} @classmethod def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):