make sure to pass kwargs when using accelerate
This commit is contained in:
@@ -200,7 +200,9 @@ def train(
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=True, **kwargs
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
|
||||
Reference in New Issue
Block a user