diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index ccdb57776..9a9ea131e 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -76,7 +76,7 @@ class ReLoRACallback(TrainerCallback): reinit=True, quantized=self.quantised, ) - reset_optimizer(optimizer) + reset_optimizer(optimizer) if self.quantised: self.last_full_model = checkpoint_folder @@ -160,7 +160,9 @@ class ReLoRAScheduler(LRScheduler): def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]: model_name = "model.safetensors" - if not os.path.exists(str(Path(path) / model_name)): + if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists( + str(Path(path) / f"{model_name}.index.json") + ): model_name = "pytorch_model.bin" index_path = str(Path(path) / f"{model_name}.index.json") @@ -230,7 +232,9 @@ def merge_and_save( in_tensors = in_tensors["state_dict"] for key in key_list: - if shard_paths[key + ".weight"] != shard_path: + if (key + ".weight") not in shard_paths or shard_paths[ + key + ".weight" + ] != shard_path: continue try: