further simplifying

This commit is contained in:
Dan Saunders
2025-04-23 23:37:41 +00:00
parent bac5568bda
commit cd393fecc3

View File

@@ -37,20 +37,13 @@ class SequenceParallelContext:
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = 0
self.local_world_size = 1
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
self.active = False
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
if self.sequence_parallel_degree > 1:
if self.process_group is None:
self.process_group = dist.group.WORLD
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
def __enter__(self):
self.active = True