Fix saving logic
This commit is contained in:
committed by
Wing Lian
parent
64a8e04430
commit
bbf88b02c1
@@ -78,7 +78,8 @@ class ReLoRACallback(TrainerCallback):
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
self.last_full_model = checkpoint_folder
|
||||
if self.quantised:
|
||||
self.last_full_model = checkpoint_folder
|
||||
self.num_lora_restarts += 1
|
||||
|
||||
return control
|
||||
@@ -95,8 +96,13 @@ class ReLoRACallback(TrainerCallback):
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
)
|
||||
if state.global_step >= self.relora_steps:
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
):
|
||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||
# ensure the latest full parameter save is in the latest checkpoint
|
||||
# folder, so that automatic pruning of checkpoints does not remove it
|
||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||
chunks = glob.glob(
|
||||
f"{self.last_full_model}/model*.safetensors"
|
||||
@@ -249,18 +255,23 @@ def merge_and_save(
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
old_dev = target.weight.device
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
target.weight = bnb.nn.Params4bit(
|
||||
new_weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=target.weight.compress_statistics,
|
||||
quant_type=target.weight.quant_type,
|
||||
).to(old_dev)
|
||||
target.weight = (
|
||||
bnb.nn.Params4bit(
|
||||
new_weight,
|
||||
requires_grad=False,
|
||||
compress_statistics=target.weight.compress_statistics,
|
||||
quant_type=target.weight.quant_type,
|
||||
)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight = bnb.nn.Int8Params(
|
||||
new_weight, requires_grad=False
|
||||
).to(old_dev)
|
||||
target.weight = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||
.cuda(None)
|
||||
.to(old_dev)
|
||||
)
|
||||
else:
|
||||
target.weight.data = new_weight.to(old_dev)
|
||||
|
||||
@@ -275,7 +286,10 @@ def merge_and_save(
|
||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||
+ ".safetensors"
|
||||
)
|
||||
st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
|
||||
|
||||
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn)
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user