further simplifying
This commit is contained in:
@@ -37,20 +37,13 @@ class SequenceParallelContext:
|
||||
self.process_group = get_ring_attn_group()
|
||||
|
||||
# Initialize sequence parallel group details
|
||||
self.local_rank = 0
|
||||
self.local_world_size = 1
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
self.active = False
|
||||
|
||||
# Will store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
if self.sequence_parallel_degree > 1:
|
||||
if self.process_group is None:
|
||||
self.process_group = dist.group.WORLD
|
||||
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user