simplifying
This commit is contained in:
@@ -39,19 +39,13 @@ class SequenceParallelContext:
|
|||||||
# Initialize sequence parallel group details
|
# Initialize sequence parallel group details
|
||||||
self.local_rank = dist.get_rank(self.process_group)
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
self.active = False
|
|
||||||
|
|
||||||
# Will store hook handles for removal
|
# Will store hook handles for removal
|
||||||
self.hook_handles: list[RemovableHandle] = []
|
self.hook_handles: list[RemovableHandle] = []
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.active = True
|
|
||||||
|
|
||||||
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
||||||
def sequence_parallel_pre_hook(module, args, kwargs):
|
def sequence_parallel_pre_hook(module, args, kwargs):
|
||||||
if not self.active or self.sequence_parallel_degree <= 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Apply sequence parallelism to kwargs
|
# Apply sequence parallelism to kwargs
|
||||||
if kwargs:
|
if kwargs:
|
||||||
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
||||||
@@ -94,8 +88,6 @@ class SequenceParallelContext:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.active = False
|
|
||||||
|
|
||||||
# Remove all hooks
|
# Remove all hooks
|
||||||
for handle in self.hook_handles:
|
for handle in self.hook_handles:
|
||||||
handle.remove()
|
handle.remove()
|
||||||
@@ -113,9 +105,6 @@ class SequenceParallelContext:
|
|||||||
Returns:
|
Returns:
|
||||||
Sliced batch dictionary.
|
Sliced batch dictionary.
|
||||||
"""
|
"""
|
||||||
if self.sequence_parallel_degree <= 1 or not self.active:
|
|
||||||
return batch
|
|
||||||
|
|
||||||
# Update ring attention params if needed
|
# Update ring attention params if needed
|
||||||
if batch.get("position_ids") is not None:
|
if batch.get("position_ids") is not None:
|
||||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||||
|
|||||||
Reference in New Issue
Block a user