simplifying
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""Module for definition of sequence parallel context manager"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import torch
|
||||
@@ -22,7 +21,7 @@ class SequenceParallelContext:
|
||||
Context manager for sequence parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply sequence parallelism
|
||||
during model forward passes using pre-forward hooks.
|
||||
during model forward passes using a pre-forward hook.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -41,57 +40,28 @@ class SequenceParallelContext:
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
self.hook_handle: RemovableHandle | None = None
|
||||
|
||||
def __enter__(self):
|
||||
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
||||
def sequence_parallel_pre_hook(module, args, kwargs):
|
||||
# Define a forward pre-hook to apply sequence parallelism with kwargs support
|
||||
def sequence_parallel_pre_hook(
|
||||
module, args, kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
# Apply sequence parallelism to kwargs
|
||||
if kwargs:
|
||||
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
||||
return args, transformed_kwargs
|
||||
|
||||
# If no kwargs but we have args, try to convert them to kwargs
|
||||
if args and not kwargs:
|
||||
try:
|
||||
signature = inspect.signature(module.forward)
|
||||
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
|
||||
|
||||
# Create kwargs from args
|
||||
new_kwargs = {}
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(param_names):
|
||||
new_kwargs[param_names[i]] = arg
|
||||
else:
|
||||
# If we can't map all args, don't transform
|
||||
return None
|
||||
|
||||
# Apply sequence parallelism to the new kwargs
|
||||
transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
|
||||
|
||||
# Return empty args and the transformed kwargs
|
||||
return (), transformed_kwargs
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, don't transform
|
||||
return None
|
||||
|
||||
# If no args and no kwargs, nothing to transform
|
||||
return None
|
||||
kwargs = self.apply_sequence_parallelism(kwargs)
|
||||
return args, kwargs
|
||||
|
||||
# Register the pre-forward hook on the model
|
||||
self.hook_handles.append(
|
||||
self.model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
self.hook_handle = self.model.register_forward_pre_hook(
|
||||
sequence_parallel_pre_hook, with_kwargs=True
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
# Remove the forward pre-hook
|
||||
self.hook_handle.remove()
|
||||
self.hook_handle = None
|
||||
|
||||
def apply_sequence_parallelism(
|
||||
self, batch: dict[str, torch.Tensor]
|
||||
|
||||
Reference in New Issue
Block a user