From 8619b2d8558f71a45ee15bf029fc58fc230adc82 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 17 Jul 2024 15:38:07 -0400
Subject: [PATCH] add torch_compile_mode options (#1763) [skip ci]

* add torch_compile_mode options

* make sure n_gpu is an int
---
 src/axolotl/cli/__init__.py                          |  2 +-
 src/axolotl/core/trainer_builder.py                  |  4 ++++
 .../utils/config/models/input/v0_4_1/__init__.py     | 12 ++++++++++++
 3 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 7ec3f524a..5966d5931 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
         cfg,
         capabilities={
             "bf16": is_torch_bf16_gpu_available(),
-            "n_gpu": os.environ.get("WORLD_SIZE", 1),
+            "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
             "compute_capability": gpu_version,
         },
     )
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 0358ad4e6..b0eea55b1 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1287,6 +1287,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "torch_compile_backend"
             ] = self.cfg.torch_compile_backend
+        if self.cfg.torch_compile_mode:
+            training_arguments_kwargs[
+                "torch_compile_mode"
+            ] = self.cfg.torch_compile_mode
 
         # DDP Config
         if self.cfg.ddp_timeout:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 32bb1f5b6..f0c6fa0ea 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -608,6 +608,9 @@ class AxolotlInputConfig(
 
     torch_compile: Optional[bool] = None
     torch_compile_backend: Optional[str] = None
+    torch_compile_mode: Optional[
+        Literal["default", "reduce-overhead", "max-autotune"]
+    ] = None
     max_steps: Optional[int] = None
     warmup_steps: Optional[int] = None
 
@@ -1161,6 +1164,15 @@ class AxolotlInputConfig(
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_torch_compile_deepspeed(cls, data):
+        if data.get("deepspeed") and data.get("torch_compile"):
+            raise ValueError(
+                "torch_compile should be set within your deepspeed config file"
+            )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""