accommodate both training context managers

This commit is contained in:
Dan Saunders
2025-04-23 23:40:45 +00:00
parent cd393fecc3
commit 7e5168ad74

View File

@@ -6,6 +6,7 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict
@@ -187,22 +188,28 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
elif cfg.sequence_parallel_degree > 1:
with SequenceParallelContext(
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContext(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
with flash_context, sequence_parallel_context:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)