diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 41f184abc..a693236d3 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -566,6 +566,10 @@ def train( resume_from_checkpoint = determine_resume_checkpoint(cfg) execute_training(cfg, trainer, resume_from_checkpoint) + # clear cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # Save the trained model and cleanup save_trained_model(cfg, trainer, model, safe_serialization) create_model_card(cfg, trainer)