diff --git a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py index 004056871..7c38b0f5a 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py @@ -64,11 +64,12 @@ class LinearLlamaConfig(LlamaConfig): def __init__(self, attention_config: Optional[dict] = None, **kwargs): super().__init__(**kwargs) - # self.auto_map = { - # "AutoConfig": "configuration_linear_llama.LinearLlamaConfig", - # "AutoModel": "modeling_linear_llama.LinearLlamaModel", - # "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM", - # } + # Set auto_map + self.auto_map = { + "AutoConfig": "configuration_linear_llama.LinearLlamaConfig", + "AutoModel": "modeling_linear_llama.LinearLlamaModel", + "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM", + } # Set default attention config if none provided self.attention_config = attention_config or {"attention_type": "softmax"}