diff --git a/src/axolotl/monkeypatch/accelerate/parallelism_config.py b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
index 9b71e914a..ebd9a6f0d 100644
--- a/src/axolotl/monkeypatch/accelerate/parallelism_config.py
+++ b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
@@ -78,30 +78,21 @@ def patch_parallelism_config():
 
 
 def patch_prepare_cp():
-    import functools
+    import contextlib
 
-    import torch
     from accelerate import Accelerator
 
     def patched_prepare_cp(self, *args):
         if self.parallelism_config.cp_backend == "deepspeed":
             return args
 
-        from accelerate.big_modeling import _attach_context_parallel_hooks
-        from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
-
-        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
-        set_rotate_method(cp_comm_strategy)
-
-        self._cp_context = functools.partial(
-            context_parallel, mesh=self.torch_device_mesh["cp"]
-        )
-
-        for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                _attach_context_parallel_hooks(arg)
+        @contextlib.contextmanager
+        def _noop_cp_context(
+            buffers=None, buffer_seq_dims=None, no_restore_buffers=None
+        ):
+            yield
 
+        self._cp_context = _noop_cp_context
         return args
 
     Accelerator._prepare_cp = patched_prepare_cp
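
For reference, a minimal standalone sketch of the pattern this patch relies on (an illustration under stated assumptions, not Axolotl or Accelerate code): the replacement `_cp_context` is a context manager that accepts the same keyword arguments as `torch.distributed.tensor.experimental.context_parallel` but does nothing, so any later code that enters `self._cp_context(...)` keeps working while no context-parallel sequence sharding is applied.

# Standalone sketch of a no-op stand-in for a context-parallel context manager.
# Assumption: callers were written against context_parallel's keyword arguments
# (buffers, buffer_seq_dims, no_restore_buffers) and only need the signature to match.
import contextlib


@contextlib.contextmanager
def _noop_cp_context(buffers=None, buffer_seq_dims=None, no_restore_buffers=None):
    # Accept and ignore the same keyword arguments the real CP context takes.
    yield


# A call site written against the real context manager still works unchanged;
# the body runs without any sequence sharding.
with _noop_cp_context(buffers=[], buffer_seq_dims=[], no_restore_buffers=set()):
    pass  # forward/backward would run here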