diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 819616425..d5dd431c1 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -218,6 +218,7 @@ def execute_training(
                 gradient_accumulation_steps=cfg.gradient_accumulation_steps,
                 ring_attn_func=cfg.ring_attn_func,
                 heads_k_stride=cfg.heads_k_stride,
+                gather_outputs=cfg.rl is RLType.GRPO,
             )
         )
 
diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py
index 491cb9877..f429cd2ae 100644
--- a/src/axolotl/utils/ctx_managers/sequence_parallel.py
+++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py
@@ -174,6 +174,8 @@ class SequenceParallelContextManager:
         ring_attn_func: Which ring attention function to use. Currently unused.
         heads_k_stride: Sequence parallelism K head stride size. Passed through to
             `varlen_llama3` `ring_flash_attn` implementation.
+        gather_outputs: Whether to gather outputs across the sequence parallel
+            group after the model forward pass.
     """
 
     def __init__(
@@ -183,12 +185,15 @@ class SequenceParallelContextManager:
         gradient_accumulation_steps: int,
         ring_attn_func: RingAttnFunc,
         heads_k_stride: int | None,
+        gather_outputs: bool,
     ):
         self.models = models
         self.sequence_parallel_degree = sequence_parallel_degree
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.ring_attn_func = ring_attn_func
         self.heads_k_stride = heads_k_stride
+        self.gather_outputs = gather_outputs
+
         self._register_ring_attn()
 
         # Set distributed info for local rank
@@ -277,16 +282,17 @@ class SequenceParallelContextManager:
             return output
 
-        # Register both hooks
+        # Register hooks
         for model in self.models:
             self.hook_handles.append(
                 model.register_forward_pre_hook(
                     sequence_parallel_pre_hook, with_kwargs=True
                 )
             )
-            self.hook_handles.append(
-                model.register_forward_hook(sequence_parallel_post_hook)
-            )
+            if self.gather_outputs:
+                self.hook_handles.append(
+                    model.register_forward_hook(sequence_parallel_post_hook)
+                )
 
     def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
         """Gather sharded outputs from all ranks and reconstruct the full tensor."""
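
For context, below is a minimal, standalone sketch of the hook-gating pattern this diff introduces: the forward pre-hook (which shards inputs in axolotl) is always registered, while the forward post-hook (which gathers sharded outputs) is now registered only when `gather_outputs` is set. The `HookManager` class, the toy model, and the hook bodies are hypothetical stand-ins for illustration, not axolotl's actual sharding/gathering logic.

```python
import torch
import torch.nn as nn


class HookManager:
    """Toy stand-in for SequenceParallelContextManager's hook registration.

    The pre-hook (input sharding in axolotl) is always installed; the
    post-hook (output gathering) is installed only when gather_outputs=True,
    mirroring the conditional introduced in this diff.
    """

    def __init__(self, models: list[nn.Module], gather_outputs: bool):
        self.models = models
        self.gather_outputs = gather_outputs
        self.hook_handles: list[torch.utils.hooks.RemovableHandle] = []

    def register_hooks(self) -> None:
        def pre_hook(module, args, kwargs):
            # Stand-in for sharding inputs along the sequence dimension.
            return args, kwargs

        def post_hook(module, args, output):
            # Stand-in for all-gathering sharded outputs across ranks.
            return output

        for model in self.models:
            self.hook_handles.append(
                model.register_forward_pre_hook(pre_hook, with_kwargs=True)
            )
            if self.gather_outputs:
                self.hook_handles.append(model.register_forward_hook(post_hook))

    def remove_hooks(self) -> None:
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()


if __name__ == "__main__":
    model = nn.Linear(4, 4)
    manager = HookManager([model], gather_outputs=False)
    manager.register_hooks()
    _ = model(torch.randn(2, 4))  # pre-hook fires; no post-hook was installed
    manager.remove_hooks()
```

Gating the post-hook this way presumably avoids an unnecessary gather in the non-GRPO path, where loss can be computed on each rank's local output shard; per the `cfg.rl is RLType.GRPO` condition in `train.py`, only GRPO training requests the full gathered outputs.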