diff --git a/src/axolotl/core/trainer_builder/sft.py b/src/axolotl/core/trainer_builder/sft.py index 88496e8c7..75eddcdf9 100644 --- a/src/axolotl/core/trainer_builder/sft.py +++ b/src/axolotl/core/trainer_builder/sft.py @@ -195,20 +195,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.greater_is_better: training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better - if self.cfg.torch_compile: - if torch._dynamo: # pylint: disable=protected-access - torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access - True + if self.cfg.torch_compile and getattr(torch, "_dynamo", None): + torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access + True + ) + training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile + if self.cfg.torch_compile_backend: + training_arguments_kwargs["torch_compile_backend"] = ( + self.cfg.torch_compile_backend + ) + if self.cfg.torch_compile_mode: + training_arguments_kwargs["torch_compile_mode"] = ( + self.cfg.torch_compile_mode ) - training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile - if self.cfg.torch_compile_backend: - training_arguments_kwargs["torch_compile_backend"] = ( - self.cfg.torch_compile_backend - ) - if self.cfg.torch_compile_mode: - training_arguments_kwargs["torch_compile_mode"] = ( - self.cfg.torch_compile_mode - ) # DDP Config if self.cfg.ddp_timeout: