diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 56e378c7f..89f35d7eb 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -509,6 +509,7 @@ def train( # Save the trained model and cleanup save_trained_model(cfg, trainer, model, safe_serialization) create_model_card(cfg, trainer) - cleanup_distributed() + if not cfg.use_ray: + cleanup_distributed() return model, tokenizer, trainer