Gate _gather_outputs (it causes increased VRAM usage) (#2829)
* Sequence parallel (SP) VRAM fix * Gate _gather_outputs (it causes increased VRAM usage) * Revert unneeded change
This commit is contained in:
@@ -218,6 +218,7 @@ def execute_training(
|
|||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
gather_outputs=cfg.rl is RLType.GRPO,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -174,6 +174,8 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
`varlen_llama3` `ring_flash_attn` implementation.
|
`varlen_llama3` `ring_flash_attn` implementation.
|
||||||
|
gather_outputs: Whether to gather outputs after model forward pass across the
|
||||||
|
sequence parallel group.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -183,12 +185,15 @@ class SequenceParallelContextManager:
|
|||||||
gradient_accumulation_steps: int,
|
gradient_accumulation_steps: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
|
gather_outputs: bool,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
|
self.gather_outputs = gather_outputs
|
||||||
|
|
||||||
self._register_ring_attn()
|
self._register_ring_attn()
|
||||||
|
|
||||||
# Set distributed info for local rank
|
# Set distributed info for local rank
|
||||||
@@ -277,16 +282,17 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# Register both hooks
|
# Register hooks
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
model.register_forward_pre_hook(
|
model.register_forward_pre_hook(
|
||||||
sequence_parallel_pre_hook, with_kwargs=True
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.hook_handles.append(
|
if self.gather_outputs:
|
||||||
model.register_forward_hook(sequence_parallel_post_hook)
|
self.hook_handles.append(
|
||||||
)
|
model.register_forward_hook(sequence_parallel_post_hook)
|
||||||
|
)
|
||||||
|
|
||||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||||
|
|||||||
Reference in New Issue
Block a user