Compare commits

...

1 Commit

Author      SHA1          Message                               Date
Wing Lian   e1c7a61243    fix reentrant when using offloading   2025-09-14 10:42:15 -04:00
2 changed files with 25 additions and 5 deletions

View File

@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
 """
 
 import contextlib
+from functools import partial
 
 from peft import PeftModel
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
+    checkpoint_wrapper,
+    CheckpointImpl,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
         return super().training_step(*args, **kwargs)
 
 
-def ac_wrap_hf_model(model: nn.Module, **kwargs):
+def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
     auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
-    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
+    if use_reentrant:
+        checkpoint_wrapper_fn = partial(
+            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
+        )
+    else:
+        checkpoint_wrapper_fn = checkpoint_wrapper
+    apply_activation_checkpointing(
+        model,
+        checkpoint_wrapper_fn=checkpoint_wrapper_fn,
+        auto_wrap_policy=auto_wrap_policy,
+        **kwargs,
+    )
 
 
 def get_lora_act_offloading_ctx_manager(
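
Note: in recent PyTorch releases, torch's checkpoint_wrapper defaults to the non-reentrant checkpoint implementation, so a requested use_reentrant=True previously had no effect once the layers were wrapped for offloading; the change above threads that flag through as CheckpointImpl.REENTRANT. Below is a minimal, self-contained sketch of the same wrapping logic outside axolotl. The Block module and the hard-coded use_reentrant value are illustrative stand-ins (axolotl's policy matches transformers.GradientCheckpointingLayer instead).

# Standalone sketch, assuming a toy Block module in place of the real HF layers.
from functools import partial

import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class Block(nn.Module):
    """Hypothetical layer type standing in for GradientCheckpointingLayer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = nn.Sequential(Block(), Block())

# Same selection as the patched ac_wrap_hf_model: only pick the reentrant
# checkpoint implementation when it was explicitly requested.
use_reentrant = True  # assumed value for the example
if use_reentrant:
    wrapper_fn = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT)
else:
    wrapper_fn = checkpoint_wrapper  # non-reentrant default

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper_fn,
    auto_wrap_policy=ModuleWrapPolicy({Block}),
)

# Reentrant checkpointing needs inputs that require grad for backward to flow.
out = model(torch.randn(4, 16, requires_grad=True))
out.sum().backward()  # Block activations are recomputed during backward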

View File

@@ -224,21 +224,27 @@ class ModelLoader:
         ):
             self.model = self.model.merge_and_unload()
 
-        self._apply_activation_checkpointing()
+        use_reentrant = None
+        if (
+            self.cfg.gradient_checkpointing_kwargs
+            and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
+        ):
+            use_reentrant = True
+        self._apply_activation_checkpointing(use_reentrant=use_reentrant)
         self._resize_token_embeddings()
         self._adjust_model_config()
         self._configure_embedding_dtypes()
         self._configure_qat()
 
         log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
 
-    def _apply_activation_checkpointing(self):
+    def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
         if self.cfg.activation_offloading is True:
             from axolotl.core.trainers.mixins.activation_checkpointing import (
                 ac_wrap_hf_model,
             )
             # ^^ importing this at the module level breaks plugins
-            ac_wrap_hf_model(self.model)
+            ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
 
     def _resize_token_embeddings(self):
         """Resize token embeddings if needed."""