Fix ReLoRA checkpoint saving logic

This commit is contained in:
Charles Goddard
2023-07-24 22:14:16 -07:00
committed by Wing Lian
parent 64a8e04430
commit bbf88b02c1

View File

@@ -78,7 +78,8 @@ class ReLoRACallback(TrainerCallback):
) )
reset_optimizer(optimizer) reset_optimizer(optimizer)
self.last_full_model = checkpoint_folder if self.quantised:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1 self.num_lora_restarts += 1
return control return control
@@ -95,8 +96,13 @@ class ReLoRACallback(TrainerCallback):
args.output_dir, args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
) )
if state.global_step >= self.relora_steps: if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantised and self.last_full_model != checkpoint_folder: if self.quantised and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}") LOG.info(f"moving last full parameter save to {checkpoint_folder}")
chunks = glob.glob( chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors" f"{self.last_full_model}/model*.safetensors"
@@ -249,18 +255,23 @@ def merge_and_save(
for adapter_name in target.lora_embedding_A: for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name) target.reset_lora_parameters(adapter_name)
old_dev = target.weight.device
if isinstance(target, peft.tuners.lora.Linear4bit): if isinstance(target, peft.tuners.lora.Linear4bit):
target.weight = bnb.nn.Params4bit( target.weight = (
new_weight, bnb.nn.Params4bit(
requires_grad=False, new_weight,
compress_statistics=target.weight.compress_statistics, requires_grad=False,
quant_type=target.weight.quant_type, compress_statistics=target.weight.compress_statistics,
).to(old_dev) quant_type=target.weight.quant_type,
)
.cuda(None)
.to(old_dev)
)
elif isinstance(target, peft.tuners.lora.Linear8bitLt): elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params( target.weight = (
new_weight, requires_grad=False bnb.nn.Int8Params(new_weight, requires_grad=False)
).to(old_dev) .cuda(None)
.to(old_dev)
)
else: else:
target.weight.data = new_weight.to(old_dev) target.weight.data = new_weight.to(old_dev)
@@ -275,7 +286,10 @@ def merge_and_save(
out_shard_name.replace("pytorch_model", "model").rstrip(".bin") out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors" + ".safetensors"
) )
st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn)
del out_tensors del out_tensors
torch.cuda.empty_cache() torch.cuda.empty_cache()