further simplifying
This commit is contained in:
@@ -37,20 +37,13 @@ class SequenceParallelContext:
|
|||||||
self.process_group = get_ring_attn_group()
|
self.process_group = get_ring_attn_group()
|
||||||
|
|
||||||
# Initialize sequence parallel group details
|
# Initialize sequence parallel group details
|
||||||
self.local_rank = 0
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
self.local_world_size = 1
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
# Will store hook handles for removal
|
# Will store hook handles for removal
|
||||||
self.hook_handles: list[RemovableHandle] = []
|
self.hook_handles: list[RemovableHandle] = []
|
||||||
|
|
||||||
if self.sequence_parallel_degree > 1:
|
|
||||||
if self.process_group is None:
|
|
||||||
self.process_group = dist.group.WORLD
|
|
||||||
|
|
||||||
self.local_rank = dist.get_rank(self.process_group)
|
|
||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.active = True
|
self.active = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user