add gather post hook, simplify, fixes

Dan Saunders
2025-04-24 14:10:03 +00:00
parent cb7c3ee847
commit 072df89e0e
4 changed files with 115 additions and 21 deletions

View File

@@ -378,11 +378,6 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
# This is needed due to details of our sequence parallel implementation; the HF
# trainer averages the loss over the full sequence length despite our splitting
# the data along the sequence dimension.
loss *= self.args.sequence_parallel_degree
return loss
@staticmethod
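
For context on the deleted rescaling: each rank computes its loss over only a 1/N shard of the sequence, while the HF trainer's reduction still divides by the full-sequence token count, leaving the result low by a factor of N. A minimal sketch of the mismatch (toy numbers, not Axolotl code):

    import torch

    sp_degree = 4                      # hypothetical sequence-parallel degree
    seq_len = 16
    per_token_loss = torch.rand(seq_len)

    # One rank only sums its 1/sp_degree shard of the sequence...
    local_sum = per_token_loss.chunk(sp_degree)[0].sum()

    # ...but the reduction divides by the FULL token count, so the result
    # is low by a factor of sp_degree.
    averaged = local_sum / seq_len

    # The deleted `loss *= self.args.sequence_parallel_degree` compensated:
    corrected = averaged * sp_degree   # == mean over the local shard

With the post-forward gather added below, outputs for the full sequence are presumably reassembled before the loss is computed, which is why the correction can be dropped.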

View File

@@ -98,12 +98,13 @@ class SequenceParallelMixin:
)
class SequenceParallelContext:
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook.
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
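
The pre/post hook pattern the docstring describes, as a standalone sketch (toy module, not the Axolotl classes):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)

    # Pre-forward hook: runs before forward(); with_kwargs=True lets it
    # rewrite both positional args and kwargs (a stand-in for sharding the
    # batch along the sequence dimension).
    def pre_hook(module, args, kwargs):
        return args, kwargs

    # Post-forward hook: runs after forward(); a non-None return value
    # replaces the module's output (a stand-in for gathering sharded outputs).
    def post_hook(module, args, output):
        return output

    handles = [
        model.register_forward_pre_hook(pre_hook, with_kwargs=True),
        model.register_forward_hook(post_hook),
    ]
    model(torch.randn(2, 8))
    for handle in handles:
        handle.remove()  # mirrors the cleanup in __exit__ below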
@@ -122,28 +123,37 @@ class SequenceParallelContext:
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handle: RemovableHandle | None = None
self.hook_handles: list[RemovableHandle] = []
def __enter__(self):
# Define a forward pre-hook to apply sequence parallelism with kwargs support
def sequence_parallel_pre_hook(
module, args, kwargs
): # pylint: disable=unused-argument
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(kwargs)
return args, kwargs
# Register the pre-forward hook on the model
self.hook_handle = self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove the forward pre-hook
self.hook_handle.remove()
self.hook_handle = None
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
@@ -199,3 +209,90 @@ class SequenceParallelContext:
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result
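
The two non-trivial layouts above, verified on a toy tensor (no distributed setup; `gathered` stands in for what `dist.all_gather` would return for a length-8 sequence across world_size = 2):

    import torch

    world_size = 2
    full = torch.arange(8.0).reshape(1, 8, 1)  # (batch, seq, hidden)

    # BATCH_ZIGZAG: rank r holds chunks r and 2*world_size - r - 1.
    chunks = full.chunk(2 * world_size, dim=1)
    gathered = [
        torch.cat([chunks[r], chunks[2 * world_size - r - 1]], dim=1)
        for r in range(world_size)
    ]  # rank 0: [0,1,6,7], rank 1: [2,3,4,5]
    rebuilt = [None] * (2 * world_size)
    for rank, g in enumerate(gathered):
        chunk1, chunk2 = g.split(g.size(1) // 2, dim=1)
        rebuilt[rank] = chunk1
        rebuilt[2 * world_size - rank - 1] = chunk2
    assert torch.equal(torch.cat(rebuilt, dim=1), full)

    # BATCH_STRIPE: rank r holds every world_size-th slice starting at r.
    gathered = [full[:, r::world_size] for r in range(world_size)]
    result = torch.zeros_like(full)
    for rank, g in enumerate(gathered):
        for i in range(g.size(1)):
            result[:, i * world_size + rank] = g[:, i]
    assert torch.equal(result, full)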

View File

@@ -26,7 +26,9 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -198,7 +200,7 @@ def execute_training(
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContext(
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
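
Call-site shape after the rename (a sketch; the enabling condition and the `trainer.train()` call are assumed from context, not shown in this hunk):

    from contextlib import nullcontext

    sequence_parallel_context = (
        SequenceParallelContextManager(
            model=trainer.model,
            sequence_parallel_degree=cfg.sequence_parallel_degree,
            ring_attn_func=cfg.ring_attn_func,
        )
        if cfg.sequence_parallel_degree > 1  # assumed condition
        else nullcontext()
    )

    with sequence_parallel_context:
        trainer.train()  # hooks shard inputs and gather outputs transparently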

View File

@@ -1161,7 +1161,7 @@ class AxolotlInputConfig(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if self.sample_packing and not self.micro_batch_size:
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"