monkeypatch.gradient_checkpointing.offload_cpu


CPU-offloaded gradient checkpointing.

Classes

Name | Description
CPU_Offloaded_Gradient_Checkpointer | Saves VRAM by offloading checkpointed activations to CPU RAM.
CheckpointFunctionWithCPUOffload | A monkey patch of CheckpointFunction from torch/utils/checkpoint.py that offloads the first tensor to the CPU during the forward pass and moves it back to CUDA during the backward pass. This allows significant memory savings at very long sequence lengths: e.g. for Llama 8B at 100k tokens, it saves about 24 GiB per GPU: (100_000 * 4096) * 2 * 32 / 2**30.

CPU_Offloaded_Gradient_Checkpointer

monkeypatch.gradient_checkpointing.offload_cpu.CPU_Offloaded_Gradient_Checkpointer()

Saves VRAM by offloading checkpointed activations to CPU RAM. The performance cost is small because the transfers use non-blocking copies, which hide most of the data movement behind computation.
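The offloading pattern described above can be sketched as a custom torch.autograd.Function. This is a minimal sketch, not this module's implementation: the class name OffloadedCheckpointSketch and its (fn, hidden_states, *args) calling convention are hypothetical, and it assumes the first tensor argument is the activation worth offloading. The forward pass parks that tensor in CPU RAM and runs fn without building a graph; the backward pass copies it back to its original device, recomputes fn with gradients enabled, and backpropagates through the recomputed graph.

```python
import torch


class OffloadedCheckpointSketch(torch.autograd.Function):
    """Hypothetical sketch: checkpoint fn(hidden_states, *args) while keeping
    the saved input in CPU RAM between the forward and backward passes."""

    @staticmethod
    def forward(ctx, fn, hidden_states, *args):
        ctx.device = hidden_states.device
        # Park the input on the CPU. non_blocking=True lets PyTorch queue the
        # device-to-host copy asynchronously on CUDA; on a CPU-only run the
        # .to("cpu") call returns the tensor without copying.
        ctx.saved_input = hidden_states.detach().to("cpu", non_blocking=True)
        ctx.fn = fn
        ctx.args = args
        # Like regular checkpointing, run the forward without recording a graph.
        with torch.no_grad():
            return fn(hidden_states, *args)

    @staticmethod
    def backward(ctx, grad_output):
        # Bring the input back to its original device as a fresh leaf tensor.
        hidden_states = (
            ctx.saved_input.to(ctx.device, non_blocking=True)
            .detach()
            .requires_grad_(True)
        )
        # Recompute the forward with autograd enabled, then backpropagate.
        # Gradients for fn's parameters accumulate during this call.
        with torch.enable_grad():
            output = ctx.fn(hidden_states, *ctx.args)
        torch.autograd.backward(output, grad_output)
        # One gradient per forward input: fn, hidden_states, then *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)
```

As with reentrant checkpointing, gradients only flow when an input such as hidden_states requires grad; the parameters of fn then receive their gradients during the recomputation in backward.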

CheckpointFunctionWithCPUOffload

monkeypatch.gradient_checkpointing.offload_cpu.CheckpointFunctionWithCPUOffload()

A monkey patch of CheckpointFunction from torch/utils/checkpoint.py that offloads the first tensor to the CPU during the forward pass and moves it back to CUDA during the backward pass. This allows significant memory savings at very long sequence lengths: e.g. for Llama 8B at 100k tokens, it saves about 24 GiB per GPU: (100_000 * 4096) * 2 * 32 / 2**30. At sequence lengths of 100k and above, the overhead of copying to and from the CPU is small, because the compute of dense quadratic attention dominates.
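The savings estimate above is the size of one [seq_len, hidden_size] activation kept per decoder layer. Below is a small worked version of that formula. The function name is hypothetical, and the reading of the factors (batch size 1, hidden size 4096, 2 bytes per bf16 element, 32 decoder layers for Llama 8B) is an interpretation of the expression rather than something the module states. Because it divides by 2**30, the result is in GiB.

```python
def offloaded_activation_gib(seq_len: int, hidden_size: int, num_layers: int,
                             bytes_per_element: int = 2) -> float:
    """Hypothetical helper: GiB held by one [seq_len, hidden_size]
    activation per layer, summed over all layers (batch size 1)."""
    total_bytes = seq_len * hidden_size * bytes_per_element * num_layers
    return total_bytes / 2**30


# Llama 8B at a 100k-token sequence: hidden size 4096, 32 layers, bf16.
print(f"{offloaded_activation_gib(100_000, 4096, 32):.1f} GiB")  # 24.4 GiB
```

Because the estimate grows linearly with seq_len while attention compute grows quadratically, the transfer cost shrinks relative to compute as sequences get longer.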