Fix logic errors
This commit is contained in:
committed by
Wing Lian
parent
b4f2eea2ed
commit
1afbd8af2d
@@ -76,7 +76,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
reinit=True,
|
||||
quantized=self.quantised,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantised:
|
||||
self.last_full_model = checkpoint_folder
|
||||
@@ -160,7 +160,9 @@ class ReLoRAScheduler(LRScheduler):
|
||||
|
||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||
model_name = "model.safetensors"
|
||||
if not os.path.exists(str(Path(path) / model_name)):
|
||||
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||
str(Path(path) / f"{model_name}.index.json")
|
||||
):
|
||||
model_name = "pytorch_model.bin"
|
||||
|
||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||
@@ -230,7 +232,9 @@ def merge_and_save(
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
for key in key_list:
|
||||
if shard_paths[key + ".weight"] != shard_path:
|
||||
if (key + ".weight") not in shard_paths or shard_paths[
|
||||
key + ".weight"
|
||||
] != shard_path:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user