further simplifying

This commit is contained in:
Dan Saunders
2025-04-23 23:37:41 +00:00
parent bac5568bda
commit cd393fecc3

View File

@@ -37,20 +37,13 @@ class SequenceParallelContext:
self.process_group = get_ring_attn_group() self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details # Initialize sequence parallel group details
self.local_rank = 0 self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = 1 self.local_world_size = dist.get_world_size(self.process_group)
self.active = False self.active = False
# Will store hook handles for removal # Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = [] self.hook_handles: list[RemovableHandle] = []
if self.sequence_parallel_degree > 1:
if self.process_group is None:
self.process_group = dist.group.WORLD
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
def __enter__(self): def __enter__(self):
self.active = True self.active = True