fix: solved double sequence partition from SequenceParallelContextManager and Accelerate's native CP (#3498)
This commit is contained in:
@@ -78,30 +78,21 @@ def patch_parallelism_config():
 
 
 def patch_prepare_cp():
-    import functools
+    import contextlib
 
-    import torch
     from accelerate import Accelerator
 
     def patched_prepare_cp(self, *args):
         if self.parallelism_config.cp_backend == "deepspeed":
             return args
 
-        from accelerate.big_modeling import _attach_context_parallel_hooks
-        from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
-
-        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
-        set_rotate_method(cp_comm_strategy)
-
-        self._cp_context = functools.partial(
-            context_parallel, mesh=self.torch_device_mesh["cp"]
-        )
-
-        for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                _attach_context_parallel_hooks(arg)
+        @contextlib.contextmanager
+        def _noop_cp_context(
+            buffers=None, buffer_seq_dims=None, no_restore_buffers=None
+        ):
+            yield
 
+        self._cp_context = _noop_cp_context
         return args
 
     Accelerator._prepare_cp = patched_prepare_cp
Reference in New Issue
Block a user