Compare commits
1 Commit

coderabbit...reentrant-
| Author | SHA1 | Date |
|---|---|---|
|  | e1c7a61243 |  |
```diff
@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
 """
 
 import contextlib
+from functools import partial
 
 from peft import PeftModel
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
+    checkpoint_wrapper,
+    CheckpointImpl,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
         return super().training_step(*args, **kwargs)
 
 
-def ac_wrap_hf_model(model: nn.Module, **kwargs):
+def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
     auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
-    apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
+    if use_reentrant:
+        checkpoint_wrapper_fn = partial(
+            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
+        )
+    else:
+        checkpoint_wrapper_fn = checkpoint_wrapper
+    apply_activation_checkpointing(
+        model,
+        checkpoint_wrapper_fn=checkpoint_wrapper_fn,
+        auto_wrap_policy=auto_wrap_policy,
+        **kwargs,
+    )
 
 
 def get_lora_act_offloading_ctx_manager(
```
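The change above makes `ac_wrap_hf_model` choose between PyTorch's reentrant and non-reentrant checkpoint implementations. Below is a minimal, self-contained sketch (not part of the patch) of the same selection logic applied to a toy model; `Block` is a hypothetical stand-in for `transformers.GradientCheckpointingLayer`, which the real wrap policy targets.

```python
from functools import partial

import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class Block(nn.Module):
    """Toy layer standing in for a GradientCheckpointingLayer subclass."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = nn.Sequential(Block(), Block())

# Mirror the patched logic: force the legacy reentrant implementation when
# use_reentrant is truthy, otherwise keep checkpoint_wrapper's default
# (non-reentrant) implementation.
use_reentrant = True
checkpoint_wrapper_fn = (
    partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT)
    if use_reentrant
    else checkpoint_wrapper
)

# Wrap every Block in-place with the chosen checkpoint wrapper.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper_fn,
    auto_wrap_policy=ModuleWrapPolicy({Block}),
)

# Activations inside each Block are recomputed during backward.
out = model(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```

Passing a `partial` with `checkpoint_impl=CheckpointImpl.REENTRANT` as `checkpoint_wrapper_fn` is how `apply_activation_checkpointing` lets callers override the default non-reentrant implementation for each wrapped module.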
```diff
@@ -224,21 +224,27 @@ class ModelLoader:
         ):
             self.model = self.model.merge_and_unload()
 
-        self._apply_activation_checkpointing()
+        use_reentrant = None
+        if (
+            self.cfg.gradient_checkpointing_kwargs
+            and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
+        ):
+            use_reentrant = True
+        self._apply_activation_checkpointing(use_reentrant=use_reentrant)
         self._resize_token_embeddings()
         self._adjust_model_config()
         self._configure_embedding_dtypes()
         self._configure_qat()
         log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
 
-    def _apply_activation_checkpointing(self):
+    def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
         if self.cfg.activation_offloading is True:
             from axolotl.core.trainers.mixins.activation_checkpointing import (
                 ac_wrap_hf_model,
             )
 
             # ^^ importing this at the module level breaks plugins
-            ac_wrap_hf_model(self.model)
+            ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
 
     def _resize_token_embeddings(self):
         """Resize token embeddings if needed."""
```
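The loader change only requests the reentrant implementation when `gradient_checkpointing_kwargs` is present in the config and does not explicitly set `use_reentrant` to `False`; otherwise `use_reentrant` stays `None` and `ac_wrap_hf_model` keeps the default non-reentrant wrapper. A small sketch of that resolution logic follows; `resolve_use_reentrant` is a hypothetical helper written for illustration, not anything in the codebase.

```python
def resolve_use_reentrant(gradient_checkpointing_kwargs: dict | None) -> bool | None:
    """Hypothetical helper mirroring the patched ModelLoader logic."""
    if gradient_checkpointing_kwargs and gradient_checkpointing_kwargs.get(
        "use_reentrant", True
    ):
        return True
    return None


# Matches the behaviour of the diff above:
assert resolve_use_reentrant(None) is None                      # kwargs unset -> default wrapper
assert resolve_use_reentrant({}) is None                        # empty dict is falsy -> default wrapper
assert resolve_use_reentrant({"use_reentrant": False}) is None  # explicitly disabled
assert resolve_use_reentrant({"use_reentrant": True}) is True   # reentrant requested
assert resolve_use_reentrant({"other": 1}) is True              # set but unspecified -> defaults to True
```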