This commit is contained in:
Dan Saunders
2025-04-24 00:02:40 +00:00
parent e5a4e21497
commit 5816433121

View File

@@ -187,8 +187,6 @@ def execute_training(
trainer: The configured trainer object. trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable. resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
""" """
LOG.info("Starting trainer...")
# Define the context managers to use # Define the context managers to use
flash_context = ( flash_context = (
torch.backends.cuda.sdp_kernel( torch.backends.cuda.sdp_kernel(
@@ -209,6 +207,7 @@ def execute_training(
else nullcontext() else nullcontext()
) )
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context: with flash_context, sequence_parallel_context:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)