simplifying
This commit is contained in:
@@ -39,19 +39,13 @@ class SequenceParallelContext:
|
||||
# Initialize sequence parallel group details
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
self.active = False
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
|
||||
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
||||
def sequence_parallel_pre_hook(module, args, kwargs):
|
||||
if not self.active or self.sequence_parallel_degree <= 1:
|
||||
return None
|
||||
|
||||
# Apply sequence parallelism to kwargs
|
||||
if kwargs:
|
||||
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
||||
@@ -94,8 +88,6 @@ class SequenceParallelContext:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.active = False
|
||||
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
@@ -113,9 +105,6 @@ class SequenceParallelContext:
|
||||
Returns:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
if self.sequence_parallel_degree <= 1 or not self.active:
|
||||
return batch
|
||||
|
||||
# Update ring attention params if needed
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
|
||||
Reference in New Issue
Block a user