commenting out unused

Dan Saunders
2025-06-16 01:53:13 +00:00
parent e34b6f4dfe
commit cbcc795bb3

@@ -42,7 +42,6 @@ from torch.distributed.tensor.experimental._attention import set_rotate_method
 from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.nn.attention.flex_attention import BlockMask
-from axolotl.utils.dict import DictDefault


 def _get_sdpa_context() -> (
     Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
@@ -103,7 +102,21 @@ def get_context_parallel_manager(
     # TODO: context parallel for multimodal models requires extra work
     # if not isinstance(model, TransformerDecoder):
     #     raise ValueError("Context parallel is only supported for text models")
-    model_buffers = list(model.buffers())
+    # model_buffers = list(model.buffers())
+
+    # def get_all_buffers(module, prefix=""):
+    #     buffers = {}
+    #     for name, buffer in module.named_buffers(recurse=False):
+    #         full_name = f"{prefix}.{name}" if prefix else name
+    #         buffers[full_name] = buffer
+    #     for name, child in module.named_children():
+    #         child_prefix = f"{prefix}.{name}" if prefix else name
+    #         buffers.update(get_all_buffers(child, child_prefix))
+    #     return buffers
+
+    # model_buffers = get_all_buffers(model)

     @contextlib.contextmanager
     def context(model_inputs: list[torch.Tensor]):
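
Note on the commented-out helper above: it re-implements by hand the recursive, name-prefixed traversal that nn.Module.named_buffers() already provides (named_buffers additionally de-duplicates shared buffers, which the manual version does not). A minimal standalone sketch of that equivalence, using a hypothetical two-layer model chosen only for this comparison:

from torch import nn

def get_all_buffers(module, prefix=""):
    # Manual recursion, mirroring the commented-out helper in the diff.
    buffers = {}
    for name, buffer in module.named_buffers(recurse=False):
        full_name = f"{prefix}.{name}" if prefix else name
        buffers[full_name] = buffer
    for name, child in module.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name
        buffers.update(get_all_buffers(child, child_prefix))
    return buffers

# Hypothetical model: BatchNorm1d contributes running_mean, running_var,
# and num_batches_tracked buffers; Linear contributes none.
model = nn.Sequential(nn.BatchNorm1d(4), nn.Linear(4, 2))
manual = get_all_buffers(model)
builtin = dict(model.named_buffers())
assert manual.keys() == builtin.keys()
assert all(manual[k] is builtin[k] for k in manual)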
@@ -114,10 +127,13 @@ def get_context_parallel_manager(
                 "Context parallel with flex attention is not yet supported"
             )
         set_rotate_method("allgather")
         cp_context = context_parallel(
             world_mesh["cp"],
-            buffers=model_inputs + model_buffers,
-            buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+            # buffers=model_inputs + model_buffers,
+            buffers=model_inputs,
+            # buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
+            buffer_seq_dims=[1] * len(model_inputs),
             no_restore_buffers=set(model_inputs),
         )
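
After this change, only the model inputs are sharded along their sequence dimension (dim 1); model buffers, which the removed arguments had sharded along dim 0, are no longer part of the context-parallel plan. A minimal sketch of the resulting call pattern, assuming a process group already initialized (e.g. via torchrun) and a device mesh whose only dimension is named "cp" (both assumptions, not shown in the diff):

import contextlib

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

# Assumed setup: dist.init_process_group() has run and the whole world
# is devoted to context parallelism under a single "cp" mesh dimension.
world_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("cp",))

@contextlib.contextmanager
def context(model_inputs: list[torch.Tensor]):
    # Shard each input along its sequence dimension (dim 1); listing the
    # inputs in no_restore_buffers skips restoring them on exit.
    with context_parallel(
        world_mesh["cp"],
        buffers=model_inputs,
        buffer_seq_dims=[1] * len(model_inputs),
        no_restore_buffers=set(model_inputs),
    ):
        yield

Leaving the buffers out of the plan means each rank keeps a full copy of buffer state (for example RoPE caches) instead of a dim-0 shard, which sidesteps the buffer-collection question the commented-out helper was exploring.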