Gate `_gather_outputs` behind a flag (gathering outputs causes increased VRAM usage) (#2829)

* Sequence parallelism (SP) VRAM fix

* Gate `_gather_outputs` behind a flag (gathering outputs causes increased VRAM usage)

* reverting unneeded change
This commit is contained in:
Dan Saunders
2025-06-25 08:33:55 -04:00
committed by GitHub
parent 46675496a3
commit 8c69ec3a1e
2 changed files with 11 additions and 4 deletions

View File

@@ -218,6 +218,7 @@ def execute_training(
gradient_accumulation_steps=cfg.gradient_accumulation_steps, gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func, ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride, heads_k_stride=cfg.heads_k_stride,
gather_outputs=cfg.rl is RLType.GRPO,
) )
) )

View File

@@ -174,6 +174,8 @@ class SequenceParallelContextManager:
ring_attn_func: Which ring attention function to use. Currently unused. ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to heads_k_stride: Sequence parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation. `varlen_llama3` `ring_flash_attn` implementation.
gather_outputs: Whether to gather outputs after model forward pass across the
sequence parallel group.
""" """
def __init__( def __init__(
@@ -183,12 +185,15 @@ class SequenceParallelContextManager:
gradient_accumulation_steps: int, gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, ring_attn_func: RingAttnFunc,
heads_k_stride: int | None, heads_k_stride: int | None,
gather_outputs: bool,
): ):
self.models = models self.models = models
self.sequence_parallel_degree = sequence_parallel_degree self.sequence_parallel_degree = sequence_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride self.heads_k_stride = heads_k_stride
self.gather_outputs = gather_outputs
self._register_ring_attn() self._register_ring_attn()
# Set distributed info for local rank # Set distributed info for local rank
@@ -277,16 +282,17 @@ class SequenceParallelContextManager:
return output return output
# Register both hooks # Register hooks
for model in self.models: for model in self.models:
self.hook_handles.append( self.hook_handles.append(
model.register_forward_pre_hook( model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True sequence_parallel_pre_hook, with_kwargs=True
) )
) )
self.hook_handles.append( if self.gather_outputs:
model.register_forward_hook(sequence_parallel_post_hook) self.hook_handles.append(
) model.register_forward_hook(sequence_parallel_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
"""Gather sharded outputs from all ranks and reconstruct the full tensor.""" """Gather sharded outputs from all ranks and reconstruct the full tensor."""