Simplify sequence parallel context manager: use a single forward pre-hook and drop the args-to-kwargs conversion

Dan Saunders
2025-04-23 23:56:31 +00:00
parent 65ae78009c
commit e5a4e21497

@@ -1,6 +1,5 @@
 """Module for definition of sequence parallel context manager"""
-import inspect
 import logging
 import torch
@@ -22,7 +21,7 @@ class SequenceParallelContext:
     Context manager for sequence parallelism operations.

     This class provides a context that will automatically apply sequence parallelism
-    during model forward passes using pre-forward hooks.
+    during model forward passes using a pre-forward hook.
     """

     def __init__(
@@ -41,57 +40,28 @@ class SequenceParallelContext:
         self.local_world_size = dist.get_world_size(self.process_group)

         # Will store hook handles for removal
-        self.hook_handles: list[RemovableHandle] = []
+        self.hook_handle: RemovableHandle | None = None

     def __enter__(self):
-        # Define a pre-forward hook to apply sequence parallelism with kwargs support
-        def sequence_parallel_pre_hook(module, args, kwargs):
-            # Apply sequence parallelism to kwargs
-            if kwargs:
-                transformed_kwargs = self.apply_sequence_parallelism(kwargs)
-                return args, transformed_kwargs
-            # If no kwargs but we have args, try to convert them to kwargs
-            if args and not kwargs:
-                try:
-                    signature = inspect.signature(module.forward)
-                    param_names = list(signature.parameters.keys())[1:]  # Skip 'self'
-                    # Create kwargs from args
-                    new_kwargs = {}
-                    for i, arg in enumerate(args):
-                        if i < len(param_names):
-                            new_kwargs[param_names[i]] = arg
-                        else:
-                            # If we can't map all args, don't transform
-                            return None
-                    # Apply sequence parallelism to the new kwargs
-                    transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
-                    # Return empty args and the transformed kwargs
-                    return (), transformed_kwargs
-                except (ValueError, TypeError):
-                    # If conversion fails, don't transform
-                    return None
-            # If no args and no kwargs, nothing to transform
-            return None
+        # Define a forward pre-hook to apply sequence parallelism with kwargs support
+        def sequence_parallel_pre_hook(
+            module, args, kwargs
+        ):  # pylint: disable=unused-argument
+            kwargs = self.apply_sequence_parallelism(kwargs)
+            return args, kwargs

         # Register the pre-forward hook on the model
-        self.hook_handles.append(
-            self.model.register_forward_pre_hook(
-                sequence_parallel_pre_hook, with_kwargs=True
-            )
-        )
+        self.hook_handle = self.model.register_forward_pre_hook(
+            sequence_parallel_pre_hook, with_kwargs=True
+        )

         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove all hooks
-        for handle in self.hook_handles:
-            handle.remove()
-        self.hook_handles = []
+        # Remove the forward pre-hook
+        self.hook_handle.remove()
+        self.hook_handle = None

     def apply_sequence_parallelism(
         self, batch: dict[str, torch.Tensor]
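
For readers unfamiliar with the hook API the simplified code leans on: PyTorch's register_forward_pre_hook, when called with with_kwargs=True, invokes the hook as hook(module, args, kwargs) and lets it return a replacement (args, kwargs) pair before forward runs. Below is a minimal, self-contained sketch of the same register-transform-remove pattern; ToyModel and the sequence-halving transform are illustrative stand-ins, not code from this repository.

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in model whose forward accepts keyword tensor arguments."""

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return input_ids * 1.0


class KwargsTransformContext:
    """Minimal mirror of the commit's simplified hook pattern."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.hook_handle = None

    def __enter__(self):
        def pre_hook(module, args, kwargs):  # pylint: disable=unused-argument
            # Stand-in for apply_sequence_parallelism: keep the first
            # half of each tensor's sequence dimension.
            kwargs = {k: v[:, : v.shape[1] // 2] for k, v in kwargs.items()}
            return args, kwargs

        self.hook_handle = self.model.register_forward_pre_hook(
            pre_hook, with_kwargs=True
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.hook_handle.remove()
        self.hook_handle = None


model = ToyModel()
with KwargsTransformContext(model):
    out = model(input_ids=torch.ones(2, 8))
assert out.shape == (2, 4)  # the pre-hook halved the sequence dimension

Note that the simplified hook assumes callers pass inputs as keyword arguments; the removed code handled positional args by mapping them onto forward's parameter names via inspect, and that fallback is gone.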
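The diff cuts off before the body of apply_sequence_parallelism. For orientation only, here is a hypothetical sketch of what such a method commonly does: shard each batch tensor along the sequence dimension so each rank in the process group keeps one contiguous slice. Everything in it, including the standalone-function form, the (batch, seq_len) layout, and the even-divisibility assumption, is assumed rather than taken from the commit.

import torch
import torch.distributed as dist


def apply_sequence_parallelism(
    batch: dict[str, torch.Tensor],
    process_group: dist.ProcessGroup,
) -> dict[str, torch.Tensor]:
    """Keep only this rank's contiguous slice of the sequence dimension.

    Hypothetical sketch; the actual method in this repository may differ.
    """
    rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)
    sharded = {}
    for name, tensor in batch.items():
        seq_len = tensor.shape[1]
        assert seq_len % world_size == 0, "sequence length must divide evenly"
        chunk = seq_len // world_size
        # Rank r keeps tokens [r * chunk, (r + 1) * chunk)
        sharded[name] = tensor[:, rank * chunk : (rank + 1) * chunk]
    return sharded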