From 7e5168ad74b5f0f5089a994888c6590d1eaf5323 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 23 Apr 2025 23:40:45 +0000 Subject: [PATCH] accommodate both training context managers --- src/axolotl/train.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index f098be475..0fe569b1e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -6,6 +6,7 @@ import os import signal import sys import weakref +from contextlib import nullcontext from pathlib import Path from typing import Any, Dict @@ -187,22 +188,28 @@ def execute_training( resume_from_checkpoint: Path to checkpoint to resume from, if applicable. """ LOG.info("Starting trainer...") - if cfg.flash_optimum: - with torch.backends.cuda.sdp_kernel( - # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... + + # Define the context managers to use + flash_context = ( + torch.backends.cuda.sdp_kernel( enable_flash=True, enable_math=True, enable_mem_efficient=True, - ): - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - elif cfg.sequence_parallel_degree > 1: - with SequenceParallelContext( + ) + if cfg.flash_optimum + else nullcontext() + ) + sequence_parallel_context = ( + SequenceParallelContext( model=trainer.model, sequence_parallel_degree=cfg.sequence_parallel_degree, ring_attn_func=cfg.ring_attn_func, - ): - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - else: + ) + if cfg.sequence_parallel_degree > 1 + else nullcontext() + ) + + with flash_context, sequence_parallel_context: trainer.train(resume_from_checkpoint=resume_from_checkpoint)