add support for SAC
This commit is contained in:
@@ -521,6 +521,11 @@ def train(
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
if cfg.activation_memory_budget is not None:
|
||||||
|
torch._functorch.config.activation_memory_budget = ( # pylint: disable=protected-access
|
||||||
|
cfg.activation_memory_budget
|
||||||
|
)
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
@@ -182,6 +182,7 @@ class AxolotlInputConfig(
|
|||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
|
activation_memory_budget: float | None = None
|
||||||
|
|
||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user