diff --git a/scripts/finetune.py b/scripts/finetune.py
index 33d3f1a51..090f8099e 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -369,22 +369,18 @@ def train(
     )
     model.config.use_cache = False

-    old_state_dict = model.state_dict
-    model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-    ).__get__(model, type(model))
-
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)

+    # go ahead and presave, so we have the adapter config available to inspect
+    lora_config.save_pretrained(cfg.output_dir)
+
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     signal.signal(
         signal.SIGINT,
         lambda signal, frame: (model.save_pretrained(cfg.output_dir), exit(0)),
    )

-    # go ahead and presave the adapter config
-    lora_config.save_pretrained(cfg.output_dir)
-
    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

    model.save_pretrained(cfg.output_dir)
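
The SIGINT handler above is the pattern worth calling out: it is registered before trainer.train(), so a run killed with Ctrl+C still writes the LoRA adapter before exiting, and the early lora_config.save_pretrained(...) serves the same goal for adapter_config.json. A minimal standalone sketch of the same idea, assuming a PEFT-wrapped model and an output directory standing in for cfg.output_dir (the helper name install_interrupt_save is hypothetical, not part of the diff):

    import signal
    import sys

    def install_interrupt_save(model, output_dir):
        # Hypothetical helper illustrating the pattern from the diff above.
        # For a PeftModel, save_pretrained() writes only the adapter weights
        # and adapter_config.json, not the full base-model weights.
        def _handler(signum, frame):
            model.save_pretrained(output_dir)
            sys.exit(0)

        # Replace the default KeyboardInterrupt behavior for Ctrl+C.
        signal.signal(signal.SIGINT, _handler)

The diff uses an inline lambda with the builtin exit(0); sys.exit(0) is the more idiomatic equivalent outside an interactive session, but the effect is the same.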