make sure to pass kwargs when using accelerate
This commit is contained in:
@@ -200,7 +200,9 @@ def train(
|
|||||||
try:
|
try:
|
||||||
if accelerate:
|
if accelerate:
|
||||||
if cloud:
|
if cloud:
|
||||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
do_cli_train(
|
||||||
|
cloud_config=cloud, config=config, accelerate=True, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
accelerate_args = []
|
accelerate_args = []
|
||||||
if "main_process_port" in kwargs:
|
if "main_process_port" in kwargs:
|
||||||
|
|||||||
Reference in New Issue
Block a user