Experimental ReLoRA (+qlora) implementation

This commit is contained in:
Charles Goddard
2023-07-24 09:53:27 -07:00
committed by Wing Lian
parent 918f1b0dfb
commit b57238ecec
6 changed files with 375 additions and 1 deletions

View File

@@ -371,8 +371,14 @@ def train(
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
if cfg.adapter == "lora" and cfg.relora_steps:
model = model.merge_and_unload()
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__":
fire.Fire(train)