Experimental ReLoRA (+qlora) implementation
This commit is contained in:
committed by
Wing Lian
parent
918f1b0dfb
commit
b57238ecec
@@ -371,8 +371,14 @@ def train(
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.adapter == "lora" and cfg.relora_steps:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
||||
|
||||
|
||||
# Script entry point: when this file is executed directly, hand `train` to
# `fire.Fire` so it is driven from the command line.
# NOTE(review): assumes `fire` is the python-fire package imported at the top
# of this file (the import is outside the visible hunk) -- confirm.
if __name__ == "__main__":
    fire.Fire(train)
|
||||
|
||||
Reference in New Issue
Block a user