From e5a4e21497adde1dc831243df7464076680ea7cc Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 23 Apr 2025 23:56:31 +0000
Subject: [PATCH] simplifying

---
 src/axolotl/core/trainers/sp.py | 56 ++++++++-------------------------
 1 file changed, 13 insertions(+), 43 deletions(-)

diff --git a/src/axolotl/core/trainers/sp.py b/src/axolotl/core/trainers/sp.py
index 1cff03382..ba6cb86a6 100644
--- a/src/axolotl/core/trainers/sp.py
+++ b/src/axolotl/core/trainers/sp.py
@@ -1,6 +1,5 @@
 """Module for definition of sequence parallel context manager"""
 
-import inspect
 import logging
 
 import torch
@@ -22,7 +21,7 @@ class SequenceParallelContext:
     Context manager for sequence parallelism operations.
 
     This class provides a context that will automatically apply sequence parallelism
-    during model forward passes using pre-forward hooks.
+    during model forward passes using a pre-forward hook.
     """
 
     def __init__(
@@ -41,57 +40,28 @@ class SequenceParallelContext:
         self.local_world_size = dist.get_world_size(self.process_group)
 
         # Will store hook handles for removal
-        self.hook_handles: list[RemovableHandle] = []
+        self.hook_handle: RemovableHandle | None = None
 
     def __enter__(self):
-        # Define a pre-forward hook to apply sequence parallelism with kwargs support
-        def sequence_parallel_pre_hook(module, args, kwargs):
+        # Define a forward pre-hook to apply sequence parallelism with kwargs support
+        def sequence_parallel_pre_hook(
+            module, args, kwargs
+        ):  # pylint: disable=unused-argument
             # Apply sequence parallelism to kwargs
-            if kwargs:
-                transformed_kwargs = self.apply_sequence_parallelism(kwargs)
-                return args, transformed_kwargs
-
-            # If no kwargs but we have args, try to convert them to kwargs
-            if args and not kwargs:
-                try:
-                    signature = inspect.signature(module.forward)
-                    param_names = list(signature.parameters.keys())[1:]  # Skip 'self'
-
-                    # Create kwargs from args
-                    new_kwargs = {}
-                    for i, arg in enumerate(args):
-                        if i < len(param_names):
-                            new_kwargs[param_names[i]] = arg
-                        else:
-                            # If we can't map all args, don't transform
-                            return None
-
-                    # Apply sequence parallelism to the new kwargs
-                    transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
-
-                    # Return empty args and the transformed kwargs
-                    return (), transformed_kwargs
-                except (ValueError, TypeError):
-                    # If conversion fails, don't transform
-                    return None
-
-            # If no args and no kwargs, nothing to transform
-            return None
+            kwargs = self.apply_sequence_parallelism(kwargs)
+            return args, kwargs
 
         # Register the pre-forward hook on the model
-        self.hook_handles.append(
-            self.model.register_forward_pre_hook(
-                sequence_parallel_pre_hook, with_kwargs=True
-            )
+        self.hook_handle = self.model.register_forward_pre_hook(
+            sequence_parallel_pre_hook, with_kwargs=True
         )
 
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove all hooks
-        for handle in self.hook_handles:
-            handle.remove()
-        self.hook_handles = []
+        # Remove the forward pre-hook
+        self.hook_handle.remove()
+        self.hook_handle = None
 
     def apply_sequence_parallelism(
         self, batch: dict[str, torch.Tensor]
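
Reviewer note (not part of the patch): the simplified hook leans entirely on PyTorch's
Module.register_forward_pre_hook(hook, with_kwargs=True) contract, where a hook taking
(module, args, kwargs) may return a (new_args, new_kwargs) pair that replaces the inputs
to forward(). Unlike the removed inspect-based path, positional arguments now pass through
untouched, so the trainer is assumed to call the model with keyword arguments. The sketch
below illustrates that mechanism in isolation; ToyModel, shard_kwargs_pre_hook, and the
half-length slicing are hypothetical stand-ins for apply_sequence_parallelism(), not part
of axolotl.

    import torch
    from torch import nn


    class ToyModel(nn.Module):
        """Hypothetical stand-in for a model that takes its batch as keyword arguments."""

        def forward(self, input_ids=None, attention_mask=None):
            # Return the input shapes so the effect of the pre-hook is visible
            return input_ids.shape, attention_mask.shape


    def shard_kwargs_pre_hook(module, args, kwargs):  # pylint: disable=unused-argument
        """Forward pre-hook (with_kwargs=True) that rewrites kwargs before forward() runs.

        Slicing each tensor kwarg to half its sequence length stands in for the real
        sequence-parallel partition across ranks.
        """
        new_kwargs = {
            name: value[:, : value.shape[1] // 2]
            if isinstance(value, torch.Tensor)
            else value
            for name, value in kwargs.items()
        }
        return args, new_kwargs


    model = ToyModel()
    handle = model.register_forward_pre_hook(shard_kwargs_pre_hook, with_kwargs=True)

    batch = {
        "input_ids": torch.arange(16).reshape(2, 8),
        "attention_mask": torch.ones(2, 8, dtype=torch.long),
    }
    print(model(**batch))  # shapes are now [2, 4]: the hook rewrote the kwargs
    print(model(batch["input_ids"], batch["attention_mask"]))  # positional args pass through unsharded

    handle.remove()  # mirrors __exit__: the hook is gone once the context is left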