diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 6b3bfbd57..2d472ed10 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -84,8 +84,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(), ), ) - return trainer.fit() - return do_train(parsed_cfg, parsed_cli_args) + + trainer.fit() + return + + do_train(parsed_cfg, parsed_cli_args) def ray_train_func(kwargs: dict):