Fix saving logic

This commit is contained in:
Charles Goddard
2023-07-24 22:14:16 -07:00
committed by Wing Lian
parent 64a8e04430
commit bbf88b02c1

View File

@@ -78,7 +78,8 @@ class ReLoRACallback(TrainerCallback):
)
reset_optimizer(optimizer)
self.last_full_model = checkpoint_folder
if self.quantised:
self.last_full_model = checkpoint_folder
self.num_lora_restarts += 1
return control
@@ -95,8 +96,13 @@ class ReLoRACallback(TrainerCallback):
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
)
if state.global_step >= self.relora_steps:
if (
state.global_step >= self.relora_steps
and state.global_step % self.relora_steps != 0
):
if self.quantised and self.last_full_model != checkpoint_folder:
# ensure the latest full parameter save is in the latest checkpoint
# folder, so that automatic pruning of checkpoints does not remove it
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
chunks = glob.glob(
f"{self.last_full_model}/model*.safetensors"
@@ -249,18 +255,23 @@ def merge_and_save(
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
old_dev = target.weight.device
if isinstance(target, peft.tuners.lora.Linear4bit):
target.weight = bnb.nn.Params4bit(
new_weight,
requires_grad=False,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
).to(old_dev)
target.weight = (
bnb.nn.Params4bit(
new_weight,
requires_grad=False,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
)
.cuda(None)
.to(old_dev)
)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(
new_weight, requires_grad=False
).to(old_dev)
target.weight = (
bnb.nn.Int8Params(new_weight, requires_grad=False)
.cuda(None)
.to(old_dev)
)
else:
target.weight.data = new_weight.to(old_dev)
@@ -275,7 +286,10 @@ def merge_and_save(
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
+ ".safetensors"
)
st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
shard_fn = str(Path(model_dst) / out_shard_name)
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn)
del out_tensors
torch.cuda.empty_cache()