diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 4eb86b6a2..ccdb57776 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -78,7 +78,8 @@ class ReLoRACallback(TrainerCallback): ) reset_optimizer(optimizer) - self.last_full_model = checkpoint_folder + if self.quantised: + self.last_full_model = checkpoint_folder self.num_lora_restarts += 1 return control @@ -95,8 +96,13 @@ class ReLoRACallback(TrainerCallback): args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", ) - if state.global_step >= self.relora_steps: + if ( + state.global_step >= self.relora_steps + and state.global_step % self.relora_steps != 0 + ): if self.quantised and self.last_full_model != checkpoint_folder: + # ensure the latest full parameter save is in the latest checkpoint + # folder, so that automatic pruning of checkpoints does not remove it LOG.info(f"moving last full parameter save to {checkpoint_folder}") chunks = glob.glob( f"{self.last_full_model}/model*.safetensors" @@ -249,18 +255,23 @@ def merge_and_save( for adapter_name in target.lora_embedding_A: target.reset_lora_parameters(adapter_name) - old_dev = target.weight.device if isinstance(target, peft.tuners.lora.Linear4bit): - target.weight = bnb.nn.Params4bit( - new_weight, - requires_grad=False, - compress_statistics=target.weight.compress_statistics, - quant_type=target.weight.quant_type, - ).to(old_dev) + target.weight = ( + bnb.nn.Params4bit( + new_weight, + requires_grad=False, + compress_statistics=target.weight.compress_statistics, + quant_type=target.weight.quant_type, + ) + .cuda(None) + .to(old_dev) + ) elif isinstance(target, peft.tuners.lora.Linear8bitLt): - target.weight = bnb.nn.Int8Params( - new_weight, requires_grad=False - ).to(old_dev) + target.weight = ( + bnb.nn.Int8Params(new_weight, requires_grad=False) + .cuda(None) + .to(old_dev) + ) else: target.weight.data = new_weight.to(old_dev) @@ -275,7 +286,10 @@ def merge_and_save( out_shard_name.replace("pytorch_model", "model").rstrip(".bin") + ".safetensors" ) - st.save_file(out_tensors, str(Path(model_dst) / out_shard_name)) + + shard_fn = str(Path(model_dst) / out_shard_name) + LOG.info(f"saving tensors to {shard_fn}") + st.save_file(out_tensors, shard_fn) del out_tensors torch.cuda.empty_cache()