gating _gather_outputs (causes increased vram usage) (#2829)
* SP vram fix
* gating _gather_outputs (causes increased vram usage)
* reverting unneeded change
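Why gathering costs memory (an explanatory note, not part of the commit): under sequence parallelism each rank only computes the logits for its own slice of the sequence, so gathering the outputs after the forward pass materializes the full-sequence logits tensor on every rank, and activation memory grows roughly with the sequence-parallel degree. GRPO is the one trainer here that is assumed to need those full-sequence outputs, which is why the gather stays enabled for it. A minimal sketch of such a gather, assuming an initialized torch.distributed setup and a hypothetical sp_group handle (gather_logits is an illustrative name, not axolotl code):

import torch
import torch.distributed as dist

def gather_logits(local_logits: torch.Tensor, sp_group) -> torch.Tensor:
    # local_logits: [batch, seq_len // sp_degree, vocab], this rank's shard
    world_size = dist.get_world_size(group=sp_group)
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits.contiguous(), group=sp_group)
    # Concatenating along the sequence dim rebuilds the full tensor on every
    # rank -- this is the allocation the new flag lets you skip.
    return torch.cat(shards, dim=1)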
@@ -218,6 +218,7 @@ def execute_training(
             gradient_accumulation_steps=cfg.gradient_accumulation_steps,
             ring_attn_func=cfg.ring_attn_func,
             heads_k_stride=cfg.heads_k_stride,
+            gather_outputs=cfg.rl is RLType.GRPO,
         )
     )
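The new argument above is a plain enum identity check: output gathering is enabled only when the configured RL mode is GRPO and stays off for every other training mode, avoiding the extra allocation. A toy illustration of the pattern; the RLType members below are made up for the example and only the comparison mirrors the diff:

from enum import Enum

class RLType(Enum):      # illustrative stand-in, not axolotl's definition
    GRPO = "grpo"
    DPO = "dpo"

def should_gather_outputs(rl) -> bool:
    # Enum members are singletons, so an `is` comparison is safe here
    return rl is RLType.GRPO

assert should_gather_outputs(RLType.GRPO) is True
assert should_gather_outputs(RLType.DPO) is False
assert should_gather_outputs(None) is False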
@@ -174,6 +174,8 @@ class SequenceParallelContextManager:
         ring_attn_func: Which ring attention function to use. Currently unused.
         heads_k_stride: Sequence parallelism K head stride size. Passed through to
             `varlen_llama3` `ring_flash_attn` implementation.
+        gather_outputs: Whether to gather outputs after model forward pass across the
+            sequence parallel group.
     """

     def __init__(
@@ -183,12 +185,15 @@ class SequenceParallelContextManager:
         gradient_accumulation_steps: int,
         ring_attn_func: RingAttnFunc,
         heads_k_stride: int | None,
+        gather_outputs: bool,
     ):
         self.models = models
         self.sequence_parallel_degree = sequence_parallel_degree
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.ring_attn_func = ring_attn_func
         self.heads_k_stride = heads_k_stride
+        self.gather_outputs = gather_outputs

         self._register_ring_attn()

         # Set distributed info for local rank
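The docstring and constructor changes are plumbing: the flag is documented, accepted in __init__, and stored on the instance so hook registration can consult it later. A hypothetical call site, limited to the parameters visible in this diff (the real signature may include arguments not shown here, and the cfg field names are assumptions):

def build_sp_manager(models, cfg):
    # Hypothetical helper; names mirror the diff, everything else is assumed.
    return SequenceParallelContextManager(
        models=models,
        sequence_parallel_degree=cfg.sequence_parallel_degree,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        ring_attn_func=cfg.ring_attn_func,
        heads_k_stride=cfg.heads_k_stride,
        gather_outputs=cfg.rl is RLType.GRPO,  # gather only when GRPO needs it
    )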
@@ -277,16 +282,17 @@ class SequenceParallelContextManager:

             return output

-        # Register both hooks
+        # Register hooks
         for model in self.models:
             self.hook_handles.append(
                 model.register_forward_pre_hook(
                     sequence_parallel_pre_hook, with_kwargs=True
                 )
             )
-            self.hook_handles.append(
-                model.register_forward_hook(sequence_parallel_post_hook)
-            )
+            if self.gather_outputs:
+                self.hook_handles.append(
+                    model.register_forward_hook(sequence_parallel_post_hook)
+                )

     def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
         """Gather sharded outputs from all ranks and reconstruct the full tensor."""
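This hunk is where the gating takes effect: the pre-hook that shards inputs is always registered, but the post-hook that gathers outputs is attached only when gather_outputs is set, and every handle is kept so the hooks can be removed later. A self-contained sketch of the same pattern in plain PyTorch, independent of the axolotl code:

import torch
import torch.nn as nn

def attach_hooks(model: nn.Module, gather_outputs: bool) -> list:
    handles = []

    def pre_hook(module, args, kwargs):
        # e.g. shard the inputs across the sequence-parallel group
        return args, kwargs

    def post_hook(module, args, output):
        # e.g. all-gather the sharded outputs back to full sequence length
        return output

    handles.append(model.register_forward_pre_hook(pre_hook, with_kwargs=True))
    if gather_outputs:  # only pay the gather (and its memory) when needed
        handles.append(model.register_forward_hook(post_hook))
    return handles

model = nn.Linear(8, 8)
handles = attach_hooks(model, gather_outputs=False)
model(torch.randn(2, 8))
for h in handles:       # detach the hooks, mirroring context-manager cleanup
    h.remove()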