Gate _gather_outputs, which causes increased VRAM usage (#2829)

* SP VRAM fix

* Gate _gather_outputs (causes increased VRAM usage)

* Revert unneeded change
Author: Dan Saunders
Date: 2025-06-25 08:33:55 -04:00
Committed by: GitHub
Parent: 46675496a3
Commit: 8c69ec3a1e
2 changed files with 11 additions and 4 deletions

@@ -218,6 +218,7 @@ def execute_training(
             gradient_accumulation_steps=cfg.gradient_accumulation_steps,
             ring_attn_func=cfg.ring_attn_func,
             heads_k_stride=cfg.heads_k_stride,
+            gather_outputs=cfg.rl is RLType.GRPO,
         )
     )
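The flag is enabled only for GRPO, presumably because GRPO is the one path here that consumes the full (unsharded) outputs after the forward pass, while other trainers keep working on their local shard. A minimal sketch of that condition follows; the Config and RLType classes and the needs_gathered_outputs helper are illustrative stand-ins, not axolotl's real definitions.

# Illustrative stand-ins only -- not axolotl's Config/RLType implementations.
from dataclasses import dataclass
from enum import Enum, auto


class RLType(Enum):
    GRPO = auto()
    DPO = auto()


@dataclass
class Config:
    rl: RLType | None = None


def needs_gathered_outputs(cfg: Config) -> bool:
    # Only GRPO needs the post-forward all-gather of the sharded outputs;
    # skipping it elsewhere avoids materializing the full tensor in VRAM.
    return cfg.rl is RLType.GRPO


assert needs_gathered_outputs(Config(rl=RLType.GRPO)) is True
assert needs_gathered_outputs(Config(rl=None)) is False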

@@ -174,6 +174,8 @@ class SequenceParallelContextManager:
         ring_attn_func: Which ring attention function to use. Currently unused.
         heads_k_stride: Sequence parallelism K head stride size. Passed through to
             the `varlen_llama3` `ring_flash_attn` implementation.
+        gather_outputs: Whether to gather outputs after the model forward pass
+            across the sequence parallel group.
     """

     def __init__(
@@ -183,12 +185,15 @@ class SequenceParallelContextManager:
         gradient_accumulation_steps: int,
         ring_attn_func: RingAttnFunc,
         heads_k_stride: int | None,
+        gather_outputs: bool,
     ):
         self.models = models
         self.sequence_parallel_degree = sequence_parallel_degree
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.ring_attn_func = ring_attn_func
         self.heads_k_stride = heads_k_stride
+        self.gather_outputs = gather_outputs

         self._register_ring_attn()

         # Set distributed info for local rank
@@ -277,16 +282,17 @@ class SequenceParallelContextManager:
             return output

-        # Register both hooks
+        # Register hooks
         for model in self.models:
             self.hook_handles.append(
                 model.register_forward_pre_hook(
                     sequence_parallel_pre_hook, with_kwargs=True
                 )
             )
-            self.hook_handles.append(
-                model.register_forward_hook(sequence_parallel_post_hook)
-            )
+            if self.gather_outputs:
+                self.hook_handles.append(
+                    model.register_forward_hook(sequence_parallel_post_hook)
+                )

     def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
         """Gather sharded outputs from all ranks and reconstruct the full tensor."""