add support for SAC
This commit is contained in:
@@ -521,6 +521,11 @@ def train(
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
if cfg.activation_memory_budget is not None:
|
||||
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
|
||||
cfg.activation_memory_budget
|
||||
)
|
||||
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||
(
|
||||
trainer,
|
||||
|
||||
@@ -182,6 +182,7 @@ class AxolotlInputConfig(
|
||||
default=False
|
||||
)
|
||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||
activation_memory_budget: float | None = None
|
||||
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user