From cd393fecc3a4bb776b91fbca3b86ed93a4299b14 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 23 Apr 2025 23:37:41 +0000 Subject: [PATCH] further simplifying --- src/axolotl/core/trainers/sp.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/axolotl/core/trainers/sp.py b/src/axolotl/core/trainers/sp.py index 41e31a6ed..28e627bd0 100644 --- a/src/axolotl/core/trainers/sp.py +++ b/src/axolotl/core/trainers/sp.py @@ -37,20 +37,13 @@ class SequenceParallelContext: self.process_group = get_ring_attn_group() # Initialize sequence parallel group details - self.local_rank = 0 - self.local_world_size = 1 + self.local_rank = dist.get_rank(self.process_group) + self.local_world_size = dist.get_world_size(self.process_group) self.active = False # Will store hook handles for removal self.hook_handles: list[RemovableHandle] = [] - if self.sequence_parallel_degree > 1: - if self.process_group is None: - self.process_group = dist.group.WORLD - - self.local_rank = dist.get_rank(self.process_group) - self.local_world_size = dist.get_world_size(self.process_group) - def __enter__(self): self.active = True