diff --git a/src/axolotl/cli/convert_linear_attention.py b/src/axolotl/cli/convert_linear_attention.py
index dee2de05c..cacde2b56 100644
--- a/src/axolotl/cli/convert_linear_attention.py
+++ b/src/axolotl/cli/convert_linear_attention.py
@@ -45,6 +45,10 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # load model
     model, tokenizer = load_model_and_tokenizer(cfg=cfg)
 
+    # freeze model
+    for p in model.parameters():
+        p.requires_grad = False
+
     # load config
     base_config = load_model_config(cfg)
 
@@ -56,6 +60,18 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
         model, config=linear_llama_config, train_attention=True
     )
 
+    # register model
+    from transformers import AutoConfig, AutoModel
+
+    AutoConfig.register("linear_llama", LinearLlamaConfig)
+    AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
+
+    # set save_path, save tokenizer and model config
+    save_path = str(os.path.join(cfg.output_dir, "distilled"))
+    tokenizer.save_pretrained(save_path)
+    if hasattr(model, "config"):
+        model.config.save_pretrained(save_path)
+
     # Get datasets
     dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
     train_dataset = dataset_meta.train_dataset
@@ -86,14 +102,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     # NOTE: If in peft mode, consider whether to auto-merge
 
     # save model
-    save_path = str(os.path.join(cfg.output_dir, "distilled"))
-    tokenizer.save_pretrained(save_path)
-    if hasattr(model, "config"):
-        model.config.save_pretrained(save_path)
-
     safe_serialization = cfg.save_safetensors is True
     # NOTE: may need to consider other ways of saving due to multi-gpu etc
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    model.save_pretrained(save_path, safe_serialization=safe_serialization)
 
     # cleanup
     plugin_manager = PluginManager.get_instance()
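
For context on why the patch registers the custom model type, below is a minimal sketch of how the `distilled` checkpoint written by this change could be reloaded through the Auto classes. The import path `axolotl.models.linear_llama` and the example output directory are assumptions for illustration, not part of this patch.

```python
# Reload sketch (assumption: LinearLlamaConfig / LinearLlamaForCausalLM are
# importable; the actual module path in the repo may differ).
from transformers import AutoConfig, AutoModel, AutoTokenizer

from axolotl.models.linear_llama import (  # hypothetical import path
    LinearLlamaConfig,
    LinearLlamaForCausalLM,
)

# Mirror the registration done in do_linearize() so the Auto classes can
# resolve the custom "linear_llama" model_type stored in config.json.
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)

save_path = "outputs/run/distilled"  # i.e. os.path.join(cfg.output_dir, "distilled")
tokenizer = AutoTokenizer.from_pretrained(save_path)
model = AutoModel.from_pretrained(save_path)
```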