add support for SAC

This commit is contained in:
Wing Lian
2025-05-23 10:33:02 -04:00
parent a27b909c5c
commit 5930c91a12
2 changed files with 6 additions and 0 deletions

View File

@@ -521,6 +521,11 @@ def train(
"""
print_axolotl_text_art()
if cfg.activation_memory_budget is not None:
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
cfg.activation_memory_budget
)
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,

View File

@@ -182,6 +182,7 @@ class AxolotlInputConfig(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
activation_memory_budget: float | None = None
unfrozen_parameters: list[str] | None = None