fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498)

This commit is contained in:
Lorenzo Baraldi
2026-03-20 10:27:24 +01:00
committed by GitHub
parent c13cb7c853
commit 038ffe3f26

View File

@@ -78,30 +78,21 @@ def patch_parallelism_config():
def patch_prepare_cp():
import functools
import contextlib
import torch
from accelerate import Accelerator
def patched_prepare_cp(self, *args):
if self.parallelism_config.cp_backend == "deepspeed":
return args
from accelerate.big_modeling import _attach_context_parallel_hooks
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(
context_parallel, mesh=self.torch_device_mesh["cp"]
)
for arg in args:
if isinstance(arg, torch.nn.Module):
_attach_context_parallel_hooks(arg)
@contextlib.contextmanager
def _noop_cp_context(
buffers=None, buffer_seq_dims=None, no_restore_buffers=None
):
yield
self._cp_context = _noop_cp_context
return args
Accelerator._prepare_cp = patched_prepare_cp