add gather post hook, simplify, fixes
This commit is contained in:
@@ -378,11 +378,6 @@ class AxolotlTrainer(
|
|||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is needed due to details of our sequence parallel implementation; the HF
|
|
||||||
# trainer averages the loss over the full sequence length despite our splitting
|
|
||||||
# the data along the sequence dimension.
|
|
||||||
loss *= self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -98,12 +98,13 @@ class SequenceParallelMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelContext:
|
class SequenceParallelContextManager:
|
||||||
"""
|
"""
|
||||||
Context manager for sequence parallelism operations.
|
Context manager for sequence parallelism operations.
|
||||||
|
|
||||||
This class provides a context that will automatically apply sequence parallelism
|
This class provides a context that will automatically apply sequence parallelism
|
||||||
during model forward passes using a pre-forward hook.
|
during model forward passes using a pre-forward hook, and gather outputs from
|
||||||
|
across the sequence parallelism group using a post-forward hook.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -122,28 +123,37 @@ class SequenceParallelContext:
|
|||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
|
|
||||||
# Will store hook handles for removal
|
# Will store hook handles for removal
|
||||||
self.hook_handle: RemovableHandle | None = None
|
self.hook_handles: list[RemovableHandle] = []
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# Define a forward pre-hook to apply sequence parallelism with kwargs support
|
# Forward pre-hook to apply sequence parallelism
|
||||||
def sequence_parallel_pre_hook(
|
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||||
module, args, kwargs
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
# Apply sequence parallelism to kwargs
|
# Apply sequence parallelism to kwargs
|
||||||
kwargs = self.apply_sequence_parallelism(kwargs)
|
kwargs = self.apply_sequence_parallelism(kwargs)
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|
||||||
# Register the pre-forward hook on the model
|
# Forward post-hook to gather outputs
|
||||||
self.hook_handle = self.model.register_forward_pre_hook(
|
def sequence_parallel_post_hook(_, __, output):
|
||||||
sequence_parallel_pre_hook, with_kwargs=True
|
# Gather the sharded outputs
|
||||||
|
return self.gather_outputs(output)
|
||||||
|
|
||||||
|
# Register both hooks
|
||||||
|
self.hook_handles.append(
|
||||||
|
self.model.register_forward_pre_hook(
|
||||||
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.hook_handles.append(
|
||||||
|
self.model.register_forward_hook(sequence_parallel_post_hook)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
# Remove the forward pre-hook
|
# Remove all hooks
|
||||||
self.hook_handle.remove()
|
for handle in self.hook_handles:
|
||||||
self.hook_handle = None
|
handle.remove()
|
||||||
|
self.hook_handles = []
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
def apply_sequence_parallelism(
|
||||||
self, batch: dict[str, torch.Tensor]
|
self, batch: dict[str, torch.Tensor]
|
||||||
@@ -199,3 +209,90 @@ class SequenceParallelContext:
|
|||||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
def gather_outputs(self, output):
    """Reassemble sharded forward-pass outputs across the process group.

    Args:
        output: The value produced by the model's forward pass. Dicts and bare
            tensors are processed; every other type is returned untouched.

    Returns:
        * For a dict: a new plain ``dict`` with the same keys. Tensors with more
          than one dimension are rebuilt with ``self.gather_tensor``; tensors
          with at most one dimension are cloned and sum-reduced over
          ``self.process_group``; non-tensor values are passed through as-is.
        * For a tensor: the result of ``self.gather_tensor(output)``.
        * Otherwise: ``output`` itself.
    """
    if isinstance(output, torch.Tensor):
        return self.gather_tensor(output)

    if not isinstance(output, dict):
        return output

    # NOTE(review): the result is a plain dict, so if the incoming mapping is a
    # dict subclass with extra behaviour (e.g. attribute access) that behaviour
    # is not carried over -- confirm callers only use key access.
    gathered_output = {}
    for key, value in output.items():
        if not isinstance(value, torch.Tensor):
            # Non-tensor entries (None, tuples, numbers, ...) have no .clone()
            # and cannot take part in a collective; keep them unchanged rather
            # than raising AttributeError.
            gathered_output[key] = value
        elif value.dim() > 1:
            # Gather logits or other sequence-sharded tensors
            gathered_output[key] = self.gather_tensor(value)
        else:
            # 0-d / 1-d tensors are summed across the group. Clone first so the
            # in-place all_reduce does not overwrite the model's own tensor.
            reduced = value.clone()
            dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=self.process_group)
            gathered_output[key] = reduced

    return gathered_output
|
||||||
|
|
||||||
|
def gather_tensor(self, tensor):
    """All-gather a sharded tensor and restore the original order along dim 1.

    Every rank in ``self.process_group`` contributes its local ``tensor``; the
    shards are then recombined along dimension 1 according to
    ``self.ring_attn_func``:

    * ``VARLEN_LLAMA3`` / ``BATCH_RING``: shards are concatenated in rank order.
    * ``BATCH_ZIGZAG``: each shard is split into two halves; the first half goes
      to slot ``rank`` and the second to slot ``2 * world_size - rank - 1``.
    * anything else (``BATCH_STRIPE``): position ``i`` of rank ``r`` lands at
      position ``i * world_size + r``.

    Args:
        tensor: Local shard with at least two dimensions. All ranks are expected
            to pass shards of identical shape, since the receive buffers are
            allocated with ``zeros_like(tensor)``.

    Returns:
        A tensor whose dimension 1 is ``world_size`` times that of ``tensor``;
        all other dimensions, dtype and device match the input.
    """
    world_size = self.local_world_size

    # One receive buffer per rank, each shaped like the local shard
    shards = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor, group=self.process_group)

    if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
        # Simple concatenation for standard (contiguous) sharding
        return torch.cat(shards, dim=1)

    if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
        # Each rank has a pattern of (rank, world_size*2-rank-1)
        ordered_chunks = [None] * (world_size * 2)
        for rank, shard in enumerate(shards):
            # Each shard holds two chunks back to back along dim 1
            chunk_size = shard.size(1) // 2
            front, back = shard.split(chunk_size, dim=1)
            ordered_chunks[rank] = front
            ordered_chunks[world_size * 2 - rank - 1] = back
        return torch.cat(ordered_chunks, dim=1)

    # Otherwise, RingAttnFunc.BATCH_STRIPE: rank r holds every world_size-th
    # position starting at r. Stacking the shards on a new axis right after
    # dim 1 gives shape (batch, local_len, world_size, ...); flattening those
    # two axes puts element (i, r) at i * world_size + r. Unlike a preallocated
    # (batch, seq, hidden) buffer filled position by position, this works for
    # any tensor with two or more dimensions and needs no Python-level loop.
    return torch.stack(shards, dim=2).flatten(1, 2)
|
||||||
|
|||||||
@@ -26,7 +26,9 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
|||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
|
from axolotl.core.trainers.mixins.sequence_parallel import (
|
||||||
|
SequenceParallelContextManager,
|
||||||
|
)
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
@@ -198,7 +200,7 @@ def execute_training(
|
|||||||
else nullcontext()
|
else nullcontext()
|
||||||
)
|
)
|
||||||
sequence_parallel_context = (
|
sequence_parallel_context = (
|
||||||
SequenceParallelContext(
|
SequenceParallelContextManager(
|
||||||
model=trainer.model,
|
model=trainer.model,
|
||||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
|
|||||||
@@ -1161,7 +1161,7 @@ class AxolotlInputConfig(
|
|||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and not self.micro_batch_size:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||||
"due to a `ring-flash-attn` requirement"
|
"due to a `ring-flash-attn` requirement"
|
||||||
|
|||||||
Reference in New Issue
Block a user