Fix logic errors
This commit is contained in:
committed by
Wing Lian
parent
b4f2eea2ed
commit
1afbd8af2d
@@ -76,7 +76,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
reinit=True,
|
reinit=True,
|
||||||
quantized=self.quantised,
|
quantized=self.quantised,
|
||||||
)
|
)
|
||||||
reset_optimizer(optimizer)
|
reset_optimizer(optimizer)
|
||||||
|
|
||||||
if self.quantised:
|
if self.quantised:
|
||||||
self.last_full_model = checkpoint_folder
|
self.last_full_model = checkpoint_folder
|
||||||
@@ -160,7 +160,9 @@ class ReLoRAScheduler(LRScheduler):
|
|||||||
|
|
||||||
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
def sharded_paths(path: str, keys: List[str]) -> Dict[str, str]:
|
||||||
model_name = "model.safetensors"
|
model_name = "model.safetensors"
|
||||||
if not os.path.exists(str(Path(path) / model_name)):
|
if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
|
||||||
|
str(Path(path) / f"{model_name}.index.json")
|
||||||
|
):
|
||||||
model_name = "pytorch_model.bin"
|
model_name = "pytorch_model.bin"
|
||||||
|
|
||||||
index_path = str(Path(path) / f"{model_name}.index.json")
|
index_path = str(Path(path) / f"{model_name}.index.json")
|
||||||
@@ -230,7 +232,9 @@ def merge_and_save(
|
|||||||
in_tensors = in_tensors["state_dict"]
|
in_tensors = in_tensors["state_dict"]
|
||||||
|
|
||||||
for key in key_list:
|
for key in key_list:
|
||||||
if shard_paths[key + ".weight"] != shard_path:
|
if (key + ".weight") not in shard_paths or shard_paths[
|
||||||
|
key + ".weight"
|
||||||
|
] != shard_path:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user