Fix saving logic
This commit is contained in:
committed by
Wing Lian
parent
64a8e04430
commit
bbf88b02c1
@@ -78,7 +78,8 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
reset_optimizer(optimizer)
|
reset_optimizer(optimizer)
|
||||||
|
|
||||||
self.last_full_model = checkpoint_folder
|
if self.quantised:
|
||||||
|
self.last_full_model = checkpoint_folder
|
||||||
self.num_lora_restarts += 1
|
self.num_lora_restarts += 1
|
||||||
|
|
||||||
return control
|
return control
|
||||||
@@ -95,8 +96,13 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
args.output_dir,
|
args.output_dir,
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
)
|
)
|
||||||
if state.global_step >= self.relora_steps:
|
if (
|
||||||
|
state.global_step >= self.relora_steps
|
||||||
|
and state.global_step % self.relora_steps != 0
|
||||||
|
):
|
||||||
if self.quantised and self.last_full_model != checkpoint_folder:
|
if self.quantised and self.last_full_model != checkpoint_folder:
|
||||||
|
# ensure the latest full parameter save is in the latest checkpoint
|
||||||
|
# folder, so that automatic pruning of checkpoints does not remove it
|
||||||
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
LOG.info(f"moving last full parameter save to {checkpoint_folder}")
|
||||||
chunks = glob.glob(
|
chunks = glob.glob(
|
||||||
f"{self.last_full_model}/model*.safetensors"
|
f"{self.last_full_model}/model*.safetensors"
|
||||||
@@ -249,18 +255,23 @@ def merge_and_save(
|
|||||||
for adapter_name in target.lora_embedding_A:
|
for adapter_name in target.lora_embedding_A:
|
||||||
target.reset_lora_parameters(adapter_name)
|
target.reset_lora_parameters(adapter_name)
|
||||||
|
|
||||||
old_dev = target.weight.device
|
|
||||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||||
target.weight = bnb.nn.Params4bit(
|
target.weight = (
|
||||||
new_weight,
|
bnb.nn.Params4bit(
|
||||||
requires_grad=False,
|
new_weight,
|
||||||
compress_statistics=target.weight.compress_statistics,
|
requires_grad=False,
|
||||||
quant_type=target.weight.quant_type,
|
compress_statistics=target.weight.compress_statistics,
|
||||||
).to(old_dev)
|
quant_type=target.weight.quant_type,
|
||||||
|
)
|
||||||
|
.cuda(None)
|
||||||
|
.to(old_dev)
|
||||||
|
)
|
||||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||||
target.weight = bnb.nn.Int8Params(
|
target.weight = (
|
||||||
new_weight, requires_grad=False
|
bnb.nn.Int8Params(new_weight, requires_grad=False)
|
||||||
).to(old_dev)
|
.cuda(None)
|
||||||
|
.to(old_dev)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
target.weight.data = new_weight.to(old_dev)
|
target.weight.data = new_weight.to(old_dev)
|
||||||
|
|
||||||
@@ -275,7 +286,10 @@ def merge_and_save(
|
|||||||
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
|
||||||
+ ".safetensors"
|
+ ".safetensors"
|
||||||
)
|
)
|
||||||
st.save_file(out_tensors, str(Path(model_dst) / out_shard_name))
|
|
||||||
|
shard_fn = str(Path(model_dst) / out_shard_name)
|
||||||
|
LOG.info(f"saving tensors to {shard_fn}")
|
||||||
|
st.save_file(out_tensors, shard_fn)
|
||||||
del out_tensors
|
del out_tensors
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user