diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 6422bde82..283bc1727 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1080,6 +1080,19 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_activation_memory_budget_w_compile(cls, data): + if data.get("activation_memory_budget") is not None and not data.get( + "torch_compile" + ): + LOG.warning( + "activation_memory_budget is enabled, but torch_compile is not set. " + "Automatically setting torch_compile to true." + ) + data["torch_compile"] = True + return data + @model_validator(mode="before") @classmethod def check_npu_config(cls, data):