fix num_processes in passing to accelerate

This commit is contained in:
Wing Lian
2025-02-06 13:39:46 -05:00
parent 4f9c57e95d
commit ecea44c902

View File

@@ -209,7 +209,7 @@ def train(
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]