commenting out unused
This commit is contained in:
@@ -42,7 +42,6 @@ from torch.distributed.tensor.experimental._attention import set_rotate_method
|
|||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
def _get_sdpa_context() -> (
|
def _get_sdpa_context() -> (
|
||||||
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
||||||
@@ -103,7 +102,21 @@ def get_context_parallel_manager(
|
|||||||
# TODO: context parallel for multimodal models requires extra work
|
# TODO: context parallel for multimodal models requires extra work
|
||||||
# if not isinstance(model, TransformerDecoder):
|
# if not isinstance(model, TransformerDecoder):
|
||||||
# raise ValueError("Context parallel is only supported for text models")
|
# raise ValueError("Context parallel is only supported for text models")
|
||||||
model_buffers = list(model.buffers())
|
# model_buffers = list(model.buffers())
|
||||||
|
|
||||||
|
# def get_all_buffers(module, prefix=""):
|
||||||
|
# buffers = {}
|
||||||
|
# for name, buffer in module.named_buffers(recurse=False):
|
||||||
|
# full_name = f"{prefix}.{name}" if prefix else name
|
||||||
|
# buffers[full_name] = buffer
|
||||||
|
|
||||||
|
# for name, child in module.named_children():
|
||||||
|
# child_prefix = f"{prefix}.{name}" if prefix else name
|
||||||
|
# buffers.update(get_all_buffers(child, child_prefix))
|
||||||
|
|
||||||
|
# return buffers
|
||||||
|
|
||||||
|
# model_buffers = get_all_buffers(model)
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def context(model_inputs: list[torch.Tensor]):
|
def context(model_inputs: list[torch.Tensor]):
|
||||||
@@ -114,10 +127,13 @@ def get_context_parallel_manager(
|
|||||||
"Context parallel with flex attention is not yet supported"
|
"Context parallel with flex attention is not yet supported"
|
||||||
)
|
)
|
||||||
set_rotate_method("allgather")
|
set_rotate_method("allgather")
|
||||||
|
|
||||||
cp_context = context_parallel(
|
cp_context = context_parallel(
|
||||||
world_mesh["cp"],
|
world_mesh["cp"],
|
||||||
buffers=model_inputs + model_buffers,
|
# buffers=model_inputs + model_buffers,
|
||||||
buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
|
buffers=model_inputs,
|
||||||
|
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
|
||||||
|
buffer_seq_dims=[1] * len(model_inputs),
|
||||||
no_restore_buffers=set(model_inputs),
|
no_restore_buffers=set(model_inputs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user