make sure to pass kwargs when using accelerate

This commit is contained in:
Wing Lian
2025-02-06 14:00:15 -05:00
parent c82cbdc6d9
commit 3df4df868c

View File

@@ -200,7 +200,9 @@ def train(
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
do_cli_train(
cloud_config=cloud, config=config, accelerate=True, **kwargs
)
else:
accelerate_args = []
if "main_process_port" in kwargs: