Compare commits
2 Commits
coderabbit
...
activation
| Author | SHA1 | Date |
|---|---|---|
|
|
7610a02881 | ||
|
|
b0cd54bcb9 |
@@ -610,3 +610,15 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
def compute_loss_context_manager(self):
    """Context manager wrapping the loss-computation step.

    Extends the parent trainer's loss context with torchtune activation
    offloading, so activations saved for backward are kept off-GPU while
    the loss is computed.

    Returns:
        A ``contextlib.ExitStack`` that closes both contexts (LIFO) when
        the caller's ``with`` block exits.

    NOTE(review): both contexts are entered eagerly here, before the
    caller's ``with`` statement begins — confirm no caller can raise
    between calling this method and entering the ``with`` block.
    """
    from contextlib import ExitStack

    from torchtune.training import OffloadActivations

    combined = ExitStack()
    # Parent context first, then offloading, so offloading is torn down
    # before the parent context on exit.
    combined.enter_context(super().compute_loss_context_manager())
    combined.enter_context(OffloadActivations())
    return combined
|
||||
|
||||
@@ -2,6 +2,13 @@
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils.checkpoint import (
|
||||
CheckpointPolicy,
|
||||
checkpoint,
|
||||
create_selective_checkpoint_contexts,
|
||||
)
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
@@ -18,3 +25,32 @@ def hf_grad_checkpoint_offload_wrapper(
|
||||
),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
compute_intensive_ops = [
|
||||
aten.mm.default,
|
||||
aten.bmm.default,
|
||||
aten.addmm.default,
|
||||
]
|
||||
|
||||
|
||||
def policy_fn(ctx, op, *args, **kwargs):
|
||||
if op in compute_intensive_ops:
|
||||
return CheckpointPolicy.MUST_SAVE
|
||||
else:
|
||||
return CheckpointPolicy.PREFER_RECOMPUTE
|
||||
|
||||
|
||||
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
|
||||
|
||||
|
||||
def checkpoint_w_policy(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
return checkpoint(
|
||||
decoder_layer,
|
||||
*args,
|
||||
use_reentrant=use_reentrant,
|
||||
context_fn=context_fn,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user