diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 54fc5d902..3864903a5 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -378,11 +378,6 @@ class AxolotlTrainer(
             num_items_in_batch=num_items_in_batch,
         )
 
-        # This is needed due to details of our sequence parallel implementation; the HF
-        # trainer averages the loss over the full sequence length depite our splitting
-        # the data along the sequence dimension.
-        loss *= self.args.sequence_parallel_degree
-
         return loss
 
     @staticmethod
diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index d53f90ade..b5101e035 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -98,12 +98,13 @@ class SequenceParallelMixin:
         )
 
 
-class SequenceParallelContext:
+class SequenceParallelContextManager:
     """
     Context manager for sequence parallelism operations.
 
     This class provides a context that will automatically apply sequence parallelism
-    during model forward passes using a pre-forward hook.
+    during model forward passes using a pre-forward hook, and gather outputs from
+    across the sequence parallelism group using a post-forward hook.
     """
 
     def __init__(
@@ -122,28 +123,37 @@ class SequenceParallelContext:
         self.local_world_size = dist.get_world_size(self.process_group)
 
         # Will store hook handles for removal
-        self.hook_handle: RemovableHandle | None = None
+        self.hook_handles: list[RemovableHandle] = []
 
     def __enter__(self):
-        # Define a forward pre-hook to apply sequence parallelism with kwargs support
-        def sequence_parallel_pre_hook(
-            module, args, kwargs
-        ):  # pylint: disable=unused-argument
+        # Forward pre-hook to apply sequence parallelism
+        def sequence_parallel_pre_hook(_, args, kwargs):
             # Apply sequence parallelism to kwargs
             kwargs = self.apply_sequence_parallelism(kwargs)
             return args, kwargs
 
-        # Register the pre-forward hook on the model
-        self.hook_handle = self.model.register_forward_pre_hook(
-            sequence_parallel_pre_hook, with_kwargs=True
+        # Forward post-hook to gather outputs
+        def sequence_parallel_post_hook(_, __, output):
+            # Gather the sharded outputs
+            return self.gather_outputs(output)
+
+        # Register both hooks
+        self.hook_handles.append(
+            self.model.register_forward_pre_hook(
+                sequence_parallel_pre_hook, with_kwargs=True
+            )
+        )
+        self.hook_handles.append(
+            self.model.register_forward_hook(sequence_parallel_post_hook)
         )
 
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove the forward pre-hook
-        self.hook_handle.remove()
-        self.hook_handle = None
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
 
     def apply_sequence_parallelism(
         self, batch: dict[str, torch.Tensor]
@@ -199,3 +209,90 @@ class SequenceParallelContext:
             batch[key] = tensor[:, self.local_rank].contiguous()
 
         return batch
+
+    def gather_outputs(self, output):
+        """Gather sharded outputs from all ranks and reconstruct the full tensor."""
+        # Handle different output formats (dict, tensor, etc.)
+        if isinstance(output, dict):
+            gathered_output = {}
+            for key, value in output.items():
+                if isinstance(value, torch.Tensor) and value.dim() > 1:
+                    # Gather logits or other sequence-sharded tensors
+                    gathered_value = self.gather_tensor(value)
+                    gathered_output[key] = gathered_value
+                else:
+                    gathered_value = value.clone()
+                    dist.all_reduce(
+                        gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
+                    )
+                    gathered_output[key] = gathered_value
+            return gathered_output
+        if isinstance(output, torch.Tensor):
+            return self.gather_tensor(output)
+
+        return output
+
+    def gather_tensor(self, tensor):
+        """Gather a sharded tensor from all ranks."""
+        # Prepare tensors for all_gather
+        world_size = self.local_world_size
+
+        # Create list to store tensors from all ranks
+        gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
+
+        # All-gather operation
+        dist.all_gather(gathered_tensors, tensor, group=self.process_group)
+
+        # Concatenate along sequence dimension (typically dim=1)
+        if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
+            # Simple concatenation for standard sharding
+            return torch.cat(gathered_tensors, dim=1)
+
+        if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+            # Each rank has a pattern of (rank, world_size*2-rank-1)
+            reconstituted_tensors = [None] * (world_size * 2)
+
+            # First, split each gathered tensor into its two chunks
+            for rank, gathered_tensor in enumerate(gathered_tensors):
+                # Each tensor contains two chunks in the sequence dimension
+                chunk_size = gathered_tensor.size(1) // 2
+                chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
+
+                # Place chunks in their original positions
+                reconstituted_tensors[rank] = chunk1
+                reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
+
+            # Concatenate the reconstituted tensors in the correct order
+            return torch.cat(reconstituted_tensors, dim=1)
+
+        # Otherwise, RingAttnFunc.BATCH_STRIPE
+        # In striping, each rank has every world_size-th slice
+        batch_size = tensor.size(0)
+        hidden_dim = tensor.size(-1)
+
+        # First, determine the full sequence length
+        total_seq_len = 0
+        for t in gathered_tensors:
+            total_seq_len += t.size(1)
+
+        # Create a tensor to hold the unstriped result
+        result = torch.zeros(
+            batch_size,
+            total_seq_len,
+            hidden_dim,
+            dtype=tensor.dtype,
+            device=tensor.device,
+        )
+
+        # For each rank's tensor, distribute its slices to the correct positions
+        for rank, gathered_tensor in enumerate(gathered_tensors):
+            # The rank's tensor contains every world_size-th slice
+            # starting from its rank position
+            seq_len = gathered_tensor.size(1)
+            for i in range(seq_len):
+                # Calculate the position in the full tensor
+                pos = i * world_size + rank
+                if pos < total_seq_len:
+                    result[:, pos] = gathered_tensor[:, i]
+
+        return result
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 02a69c909..d116ea4fd 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -26,7 +26,9 @@ from axolotl.contribs.lgpl import (  # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
+from axolotl.core.trainers.mixins.sequence_parallel import (
+    SequenceParallelContextManager,
+)
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
@@ -198,7 +200,7 @@ def execute_training(
         else nullcontext()
     )
     sequence_parallel_context = (
-        SequenceParallelContext(
+        SequenceParallelContextManager(
             model=trainer.model,
             sequence_parallel_degree=cfg.sequence_parallel_degree,
             ring_attn_func=cfg.ring_attn_func,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index b749bebc2..f68d160df 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1161,7 +1161,7 @@ class AxolotlInputConfig(
                 "flash_attention: true must be set with sequence_parallel_degree > 1"
            )
 
-        if self.sample_packing and not self.micro_batch_size:
+        if self.sample_packing and self.micro_batch_size > 1:
            raise ValueError(
                "micro_batch_size must be set to 1 when sample_packing is enabled"
                "due to a `ring-flash-attn` requirement"
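
Usage note (illustrative sketch, not part of the patch): the snippet below shows how the renamed SequenceParallelContextManager is driven from execute_training, mirroring the train.py hunk above. `trainer` and `cfg` stand in for the objects Axolotl has already constructed at that point, and the `cfg.sequence_parallel_degree > 1` guard and bare `trainer.train()` call are assumptions for brevity. With the post-forward gather in place, the loss is computed over full-length outputs, which is why the loss rescaling in base.py is removed.

    from contextlib import nullcontext

    from axolotl.core.trainers.mixins.sequence_parallel import (
        SequenceParallelContextManager,
    )

    # Only wrap training when sequence parallelism is active; otherwise a no-op.
    sequence_parallel_context = (
        SequenceParallelContextManager(
            model=trainer.model,
            sequence_parallel_degree=cfg.sequence_parallel_degree,
            ring_attn_func=cfg.ring_attn_func,
        )
        if cfg.sequence_parallel_degree > 1
        else nullcontext()
    )

    with sequence_parallel_context:
        # Pre-forward hook shards inputs along the sequence dimension;
        # post-forward hook all-gathers the sharded outputs back together.
        trainer.train()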