simplifying
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
"""Module for definition of sequence parallel context manager"""
|
"""Module for definition of sequence parallel context manager"""
|
||||||
|
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -22,7 +21,7 @@ class SequenceParallelContext:
|
|||||||
Context manager for sequence parallelism operations.
|
Context manager for sequence parallelism operations.
|
||||||
|
|
||||||
This class provides a context that will automatically apply sequence parallelism
|
This class provides a context that will automatically apply sequence parallelism
|
||||||
during model forward passes using pre-forward hooks.
|
during model forward passes using a pre-forward hook.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -41,57 +40,28 @@ class SequenceParallelContext:
|
|||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
|
|
||||||
# Will store hook handles for removal
|
# Will store hook handles for removal
|
||||||
self.hook_handles: list[RemovableHandle] = []
|
self.hook_handle: RemovableHandle | None = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
# Define a forward pre-hook to apply sequence parallelism with kwargs support
|
||||||
def sequence_parallel_pre_hook(module, args, kwargs):
|
def sequence_parallel_pre_hook(
|
||||||
|
module, args, kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
# Apply sequence parallelism to kwargs
|
# Apply sequence parallelism to kwargs
|
||||||
if kwargs:
|
kwargs = self.apply_sequence_parallelism(kwargs)
|
||||||
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
return args, kwargs
|
||||||
return args, transformed_kwargs
|
|
||||||
|
|
||||||
# If no kwargs but we have args, try to convert them to kwargs
|
|
||||||
if args and not kwargs:
|
|
||||||
try:
|
|
||||||
signature = inspect.signature(module.forward)
|
|
||||||
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
|
|
||||||
|
|
||||||
# Create kwargs from args
|
|
||||||
new_kwargs = {}
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if i < len(param_names):
|
|
||||||
new_kwargs[param_names[i]] = arg
|
|
||||||
else:
|
|
||||||
# If we can't map all args, don't transform
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Apply sequence parallelism to the new kwargs
|
|
||||||
transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
|
|
||||||
|
|
||||||
# Return empty args and the transformed kwargs
|
|
||||||
return (), transformed_kwargs
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# If conversion fails, don't transform
|
|
||||||
return None
|
|
||||||
|
|
||||||
# If no args and no kwargs, nothing to transform
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Register the pre-forward hook on the model
|
# Register the pre-forward hook on the model
|
||||||
self.hook_handles.append(
|
self.hook_handle = self.model.register_forward_pre_hook(
|
||||||
self.model.register_forward_pre_hook(
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
sequence_parallel_pre_hook, with_kwargs=True
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Remove all hooks
|
# Remove the forward pre-hook
|
||||||
for handle in self.hook_handles:
|
self.hook_handle.remove()
|
||||||
handle.remove()
|
self.hook_handle = None
|
||||||
self.hook_handles = []
|
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
def apply_sequence_parallelism(
|
||||||
self, batch: dict[str, torch.Tensor]
|
self, batch: dict[str, torch.Tensor]
|
||||||
|
|||||||
Reference in New Issue
Block a user