This commit is contained in:
Dan Saunders
2025-04-24 00:02:40 +00:00
parent e5a4e21497
commit 5816433121

View File

@@ -187,8 +187,6 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
@@ -209,6 +207,7 @@ def execute_training(
else nullcontext()
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)