chore: simplify dynamo check

This commit is contained in:
NanoCode012
2025-05-15 20:28:13 +07:00
parent 64a57ebb62
commit 7fd05c19f7

View File

@@ -195,20 +195,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
# DDP Config
if self.cfg.ddp_timeout: