From 5930c91a12fa2280990733e28ae0ff566df0f734 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 23 May 2025 10:33:02 -0400 Subject: [PATCH] add support for SAC --- src/axolotl/train.py | 5 +++++ src/axolotl/utils/schemas/config.py | 1 + 2 files changed, 6 insertions(+) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f..b3ae09750 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -521,6 +521,11 @@ def train( """ print_axolotl_text_art() + if cfg.activation_memory_budget is not None: + torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access + cfg.activation_memory_budget + ) + # Setup model, tokenizer, (causal or RLHF) trainer, etc. ( trainer, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 8ae9d5c04..6422bde82 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -182,6 +182,7 @@ class AxolotlInputConfig( default=False ) gradient_checkpointing_kwargs: dict[str, Any] | None = None + activation_memory_budget: float | None = None unfrozen_parameters: list[str] | None = None