Fix logic errors
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Charles Goddard
2023-07-25 16:19:53 -07:00
committed by Wing Lian
parent b4f2eea2ed
commit 1afbd8af2d

View File

@@ -76,7 +76,7 @@ class ReLoRACallback(TrainerCallback):
reinit=True,
quantized=self.quantised,
)
reset_optimizer(optimizer)
reset_optimizer(optimizer)
if self.quantised:
self.last_full_model = checkpoint_folder
@@ -160,7 +160,9 @@ class ReLoRAScheduler(LRScheduler):
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
model_name = "model.safetensors"
if not os.path.exists(str(Path(path) / model_name)):
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
str(Path(path) / f"{model_name}.index.json")
):
model_name = "pytorch_model.bin"
index_path = str(Path(path) / f"{model_name}.index.json")
@@ -230,7 +232,9 @@ def merge_and_save(
in_tensors = in_tensors["state_dict"]
for key in key_list:
if shard_paths[key + ".weight"] != shard_path:
if (key + ".weight") not in shard_paths or shard_paths[
key + ".weight"
] != shard_path:
continue
try: