fix: use existing model config

This commit is contained in:
NanoCode012
2025-02-06 00:12:14 +07:00
parent c15ea6b956
commit caa49a9d7d

View File

@@ -49,12 +49,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
for p in model.parameters():
p.requires_grad = False
-    # load config
-    base_config = load_model_config(cfg)
     # convert to linear llama
     linear_llama_config = LinearLlamaConfig.from_llama(
-        base_config, cfg.attention_config
+        model.config, cfg.attention_config
     )
model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True