diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 7ec3f524a..5966d5931 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
         cfg,
         capabilities={
             "bf16": is_torch_bf16_gpu_available(),
-            "n_gpu": os.environ.get("WORLD_SIZE", 1),
+            "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
             "compute_capability": gpu_version,
         },
     )
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 0358ad4e6..b0eea55b1 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1287,6 +1287,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "torch_compile_backend"
             ] = self.cfg.torch_compile_backend
+            if self.cfg.torch_compile_mode:
+                training_arguments_kwargs[
+                    "torch_compile_mode"
+                ] = self.cfg.torch_compile_mode
 
         # DDP Config
         if self.cfg.ddp_timeout:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 32bb1f5b6..f0c6fa0ea 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -608,6 +608,9 @@ class AxolotlInputConfig(
 
     torch_compile: Optional[bool] = None
     torch_compile_backend: Optional[str] = None
+    torch_compile_mode: Optional[
+        Literal["default", "reduce-overhead", "max-autotune"]
+    ] = None
 
     max_steps: Optional[int] = None
     warmup_steps: Optional[int] = None
@@ -1161,6 +1164,15 @@
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_torch_compile_deepspeed(cls, data):
+        if data.get("deepspeed") and data.get("torch_compile"):
+            raise ValueError(
+                "torch_compile should be set within your deepspeed config file"
+            )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""