From cbcc795bb3d3ddf38497f8989888063e8f8ed85a Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 16 Jun 2025 01:53:13 +0000
Subject: [PATCH] commenting out unused

---
 .../context_parallel/distributed.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
index 13adeb132..469369fc6 100644
--- a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
+++ b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py
@@ -42,7 +42,6 @@ from torch.distributed.tensor.experimental._attention import set_rotate_method
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.nn.attention.flex_attention import BlockMask
 
-from axolotl.utils.dict import DictDefault
 
 def _get_sdpa_context() -> (
     Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -103,7 +102,21 @@ def get_context_parallel_manager(
     # TODO: context parallel for multimodal models requires extra work
     # if not isinstance(model, TransformerDecoder):
     #     raise ValueError("Context parallel is only supported for text models")
-    model_buffers = list(model.buffers())
+    # model_buffers = list(model.buffers())
+
+    # def get_all_buffers(module, prefix=""):
+    #     buffers = {}
+    #     for name, buffer in module.named_buffers(recurse=False):
+    #         full_name = f"{prefix}.{name}" if prefix else name
+    #         buffers[full_name] = buffer
+
+    #     for name, child in module.named_children():
+    #         child_prefix = f"{prefix}.{name}" if prefix else name
+    #         buffers.update(get_all_buffers(child, child_prefix))
+
+    #     return buffers
+
+    # model_buffers = get_all_buffers(model)
 
     @contextlib.contextmanager
     def context(model_inputs: list[torch.Tensor]):
@@ -114,10 +127,13 @@ def get_context_parallel_manager(
                 "Context parallel with flex attention is not yet supported"
            )
         set_rotate_method("allgather")
+
         cp_context = context_parallel(
             world_mesh["cp"],
-            buffers=model_inputs + model_buffers,
-            buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+            # buffers=model_inputs + model_buffers,
+            buffers=model_inputs,
+            # buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+            buffer_seq_dims=[1] * len(model_inputs),
             no_restore_buffers=set(model_inputs),
         )