Compare commits
2 Commits
fix/diffus
...
activation
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7610a02881 | ||
|
|
b0cd54bcb9 |
@@ -610,3 +610,15 @@ class AxolotlTrainer(
|
|||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def compute_loss_context_manager(self):
    """Combine the parent's loss context with torchtune's ``OffloadActivations``.

    Both context managers are entered while this method runs (not when the
    returned object is later used in a ``with`` statement); the returned
    stack only owns their exit.

    Returns:
        contextlib.ExitStack: a stack holding the two already-entered
        contexts. Exiting it unwinds them in reverse order of entry.
    """
    from contextlib import ExitStack

    from torchtune.training import OffloadActivations

    # Enter the contexts under a guard stack so that, if the second
    # ``enter_context`` raises, the first context is still exited instead of
    # being left open.
    with ExitStack() as guard:
        guard.enter_context(super().compute_loss_context_manager())
        guard.enter_context(OffloadActivations())
        # Everything was entered successfully: transfer the exit callbacks to
        # a fresh stack so leaving this ``with`` block closes nothing.
        return guard.pop_all()
|
||||||
|
|||||||
@@ -2,6 +2,13 @@
|
|||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.checkpoint import (
|
||||||
|
CheckpointPolicy,
|
||||||
|
checkpoint,
|
||||||
|
create_selective_checkpoint_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
)
|
||||||
@@ -18,3 +25,32 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
),
|
),
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Shorthand for the ATen operator namespace exposed by torch.
aten = torch.ops.aten

# Operators that ``policy_fn`` maps to ``CheckpointPolicy.MUST_SAVE``; every
# other op is mapped to ``PREFER_RECOMPUTE``.
compute_intensive_ops = [aten.mm.default, aten.bmm.default, aten.addmm.default]
|
||||||
|
|
||||||
|
|
||||||
|
def policy_fn(ctx, op, *args, **kwargs):
    """Pick the checkpoint policy for a single operator.

    Args:
        ctx: Unused; accepted to match the policy callback signature.
        op: The operator being evaluated.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        ``CheckpointPolicy.MUST_SAVE`` when ``op`` is listed in
        ``compute_intensive_ops``, otherwise
        ``CheckpointPolicy.PREFER_RECOMPUTE``.
    """
    is_compute_intensive = op in compute_intensive_ops
    return (
        CheckpointPolicy.MUST_SAVE
        if is_compute_intensive
        else CheckpointPolicy.PREFER_RECOMPUTE
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Factory for selective-checkpoint contexts with ``policy_fn`` pre-bound as
# the first argument; passed as ``context_fn`` to ``checkpoint``.
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_w_policy(decoder_layer, *args, use_reentrant=None):
    """Run ``decoder_layer`` under ``torch.utils.checkpoint.checkpoint`` with the selective policy.

    Args:
        decoder_layer: Callable to checkpoint; invoked with ``*args``.
        *args: Positional arguments forwarded to ``decoder_layer``.
        use_reentrant: Forwarded to ``checkpoint``. ``None`` (the default) is
            treated as ``False`` because a custom ``context_fn`` is only
            supported by the non-reentrant implementation. An explicit
            ``True`` is passed through unchanged and will be rejected by
            torch.

    Returns:
        Whatever ``checkpoint`` returns for ``decoder_layer(*args)``.
    """
    # Leaving this as None lets torch fall back to the reentrant path, which
    # refuses a non-default ``context_fn``.
    if use_reentrant is None:
        use_reentrant = False
    return checkpoint(
        decoder_layer,
        *args,
        use_reentrant=use_reentrant,
        context_fn=context_fn,
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user