Simplify SequenceParallelContext: remove the `active` flag and its early-return guards

This commit is contained in:
Dan Saunders
2025-04-23 23:49:11 +00:00
parent 7e5168ad74
commit 65ae78009c

View File

@@ -39,19 +39,13 @@ class SequenceParallelContext:
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
self.active = False
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
def __enter__(self):
self.active = True
# Define a pre-forward hook to apply sequence parallelism with kwargs support
def sequence_parallel_pre_hook(module, args, kwargs):
if not self.active or self.sequence_parallel_degree <= 1:
return None
# Apply sequence parallelism to kwargs
if kwargs:
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
@@ -94,8 +88,6 @@ class SequenceParallelContext:
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.active = False
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
@@ -113,9 +105,6 @@ class SequenceParallelContext:
Returns:
Sliced batch dictionary.
"""
if self.sequence_parallel_degree <= 1 or not self.active:
return batch
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])