diff --git a/src/axolotl/core/trainers/sp.py b/src/axolotl/core/trainers/sp.py
index 28e627bd0..1cff03382 100644
--- a/src/axolotl/core/trainers/sp.py
+++ b/src/axolotl/core/trainers/sp.py
@@ -39,19 +39,13 @@ class SequenceParallelContext:
         # Initialize sequence parallel group details
         self.local_rank = dist.get_rank(self.process_group)
         self.local_world_size = dist.get_world_size(self.process_group)
-        self.active = False
 
         # Will store hook handles for removal
         self.hook_handles: list[RemovableHandle] = []
 
     def __enter__(self):
-        self.active = True
-
         # Define a pre-forward hook to apply sequence parallelism with kwargs support
         def sequence_parallel_pre_hook(module, args, kwargs):
-            if not self.active or self.sequence_parallel_degree <= 1:
-                return None
-
             # Apply sequence parallelism to kwargs
             if kwargs:
                 transformed_kwargs = self.apply_sequence_parallelism(kwargs)
@@ -94,8 +88,6 @@ class SequenceParallelContext:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.active = False
-
         # Remove all hooks
         for handle in self.hook_handles:
             handle.remove()
@@ -113,9 +105,6 @@ class SequenceParallelContext:
         Returns:
             Sliced batch dictionary.
         """
-        if self.sequence_parallel_degree <= 1 or not self.active:
-            return batch
-
         # Update ring attention params if needed
         if batch.get("position_ids") is not None:
             update_ring_attn_params(position_ids=batch["position_ids"])
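
For reference, here is a minimal, self-contained sketch of the pattern this diff converges on. It is illustrative only, not the axolotl implementation: the class and names (`ShardingContext`, `_pre_hook`) are hypothetical. The point is that when a hook is registered in `__enter__` and its handle removed in `__exit__`, the hook's lifetime already coincides with the context's, so a separate `active` flag is redundant state that can never be observed as `False` while the hook is still installed.

```python
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle


class ShardingContext:
    """Hypothetical sketch: register a kwargs-aware pre-forward hook on
    enter, remove it on exit. No `active` flag is needed: once the
    handle is removed, the hook simply never fires again."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.hook_handles: list[RemovableHandle] = []

    def _pre_hook(self, module, args, kwargs):
        # A real sequence-parallel hook would slice the inputs along the
        # sequence dimension for this rank; here we pass them through.
        # Returning (args, kwargs) replaces the module's forward inputs.
        return args, kwargs

    def __enter__(self):
        # with_kwargs=True makes PyTorch pass keyword arguments to the
        # hook (signature: hook(module, args, kwargs)).
        handle = self.module.register_forward_pre_hook(
            self._pre_hook, with_kwargs=True
        )
        self.hook_handles.append(handle)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Removing the handles detaches the hooks from the module, so no
        # guard inside the hook body is needed to deactivate them.
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()


if __name__ == "__main__":
    model = nn.Linear(4, 4)
    with ShardingContext(model):
        model(torch.randn(2, 4))  # pre-hook fires here
    model(torch.randn(2, 4))  # hook removed; forward runs untouched
```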