fix: resolve double sequence partitioning by SequenceParallelContextManager and Accelerate's native CP (#3498)
@@ -78,30 +78,21 @@ def patch_parallelism_config():
 def patch_prepare_cp():
-    import functools
+    import contextlib

     import torch
     from accelerate import Accelerator

     def patched_prepare_cp(self, *args):
         if self.parallelism_config.cp_backend == "deepspeed":
             return args

         from accelerate.big_modeling import _attach_context_parallel_hooks
-        from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
-
-        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
-        set_rotate_method(cp_comm_strategy)
-
-        self._cp_context = functools.partial(
-            context_parallel, mesh=self.torch_device_mesh["cp"]
-        )
-
         for arg in args:
             if isinstance(arg, torch.nn.Module):
                 _attach_context_parallel_hooks(arg)
+
+        @contextlib.contextmanager
+        def _noop_cp_context(
+            buffers=None, buffer_seq_dims=None, no_restore_buffers=None
+        ):
+            yield
+
+        self._cp_context = _noop_cp_context
         return args

     Accelerator._prepare_cp = patched_prepare_cp
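
Why the no-op is safe here, as a minimal sketch: it keeps the calling convention of torch.distributed.tensor.experimental.context_parallel (whose mesh argument the removed functools.partial used to bind) while leaving the buffers unsharded, since SequenceParallelContextManager has already split the sequence dimension. The tensor shapes below are illustrative only:

import contextlib

import torch


@contextlib.contextmanager
def _noop_cp_context(buffers=None, buffer_seq_dims=None, no_restore_buffers=None):
    # Accepts the same keyword arguments call sites already pass, but does not
    # shard `buffers` along `buffer_seq_dims`; sharding again here is what
    # produced the double sequence split this commit removes.
    yield


batch = torch.randn(2, 8, 16)  # (batch, seq, hidden); illustrative shapes
with _noop_cp_context(buffers=[batch], buffer_seq_dims=[1]):
    assert batch.shape[1] == 8  # sequence dimension passes through unsharded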
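
A side note on the patching mechanism, since it determines when the fix takes effect: assigning a plain function to the class rebinds the method for every instance, so patch_prepare_cp() must run before any Accelerator invokes _prepare_cp. A generic sketch of that pattern with stand-in names (Greeter and patched_greet are hypothetical):

class Greeter:
    def greet(self):
        return "hello"


def patched_greet(self):
    # A function assigned to a class attribute becomes a method via the
    # descriptor protocol, so it receives `self` like the original did.
    return "hi"


# Rebind on the class, as the commit does with Accelerator._prepare_cp.
Greeter.greet = patched_greet
assert Greeter().greet() == "hi"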