diff --git a/src/axolotl/cli/convert_linear_attention.py b/src/axolotl/cli/convert_linear_attention.py index dae7d7bb0..5045b3545 100644 --- a/src/axolotl/cli/convert_linear_attention.py +++ b/src/axolotl/cli/convert_linear_attention.py @@ -49,12 +49,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: for p in model.parameters(): p.requires_grad = False - # load config - base_config = load_model_config(cfg) - # convert to linear llama linear_llama_config = LinearLlamaConfig.from_llama( - base_config, cfg.attention_config + model.config, cfg.attention_config ) model = LinearLlamaForCausalLM.from_llama( model, config=linear_llama_config, train_attention=True