This commit is contained in:
Dan Saunders
2025-04-23 23:19:52 +00:00
parent cafda804ec
commit 69aeae80ed
4 changed files with 112 additions and 274 deletions

View File

@@ -371,13 +371,20 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
) )
return super().compute_loss( loss = super().compute_loss(
model, model,
inputs, inputs,
return_outputs=return_outputs, return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch, num_items_in_batch=num_items_in_batch,
) )
# This is needed due to details of our sequence parallel implementation; the HF
# trainer averages the loss over the full sequence length depite our splitting
# the data along the sequence dimension.
loss *= self.args.sequence_parallel_degree
return loss
@staticmethod @staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {} concatenated_batch = {}

View File

@@ -1,11 +1,11 @@
import contextlib """Module for definition of sequence parallel context manager"""
import functools
import inspect
import logging import logging
from typing import Dict, List, Optional, Set
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn from torch import nn
from torch.utils.hooks import RemovableHandle from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn.patch import ( from axolotl.monkeypatch.attention.ring_attn.patch import (
@@ -22,17 +22,16 @@ class SequenceParallelContext:
Context manager for sequence parallelism operations. Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism This class provides a context that will automatically apply sequence parallelism
during model forward passes. during model forward passes using pre-forward hooks.
""" """
# Keep track of active contexts to support nested contexts
_active_contexts = []
def __init__( def __init__(
self, self,
model: nn.Module,
sequence_parallel_degree: int, sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc, ring_attn_func: RingAttnFunc,
): ):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group() self.process_group = get_ring_attn_group()
@@ -42,9 +41,8 @@ class SequenceParallelContext:
self.local_world_size = 1 self.local_world_size = 1
self.active = False self.active = False
# Will store original methods for restoration # Will store hook handles for removal
self._original_module_forward = None self.hook_handles: list[RemovableHandle] = []
self._hooks: List[RemovableHandle] = []
if self.sequence_parallel_degree > 1: if self.sequence_parallel_degree > 1:
if self.process_group is None: if self.process_group is None:
@@ -55,75 +53,64 @@ class SequenceParallelContext:
def __enter__(self): def __enter__(self):
self.active = True self.active = True
SequenceParallelContext._active_contexts.append(self)
# Store the original forward method # Define a pre-forward hook to apply sequence parallelism with kwargs support
if self._original_module_forward is None: def sequence_parallel_pre_hook(module, args, kwargs):
self._original_module_forward = nn.Module.forward if not self.active or self.sequence_parallel_degree <= 1:
return None
# Replace nn.Module.forward with our sequence parallel version # Apply sequence parallelism to kwargs
nn.Module.forward = self._make_sequence_parallel_forward(nn.Module.forward) if kwargs:
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
return args, transformed_kwargs
# If no kwargs but we have args, try to convert them to kwargs
if args and not kwargs:
try:
signature = inspect.signature(module.forward)
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
# Create kwargs from args
new_kwargs = {}
for i, arg in enumerate(args):
if i < len(param_names):
new_kwargs[param_names[i]] = arg
else:
# If we can't map all args, don't transform
return None
# Apply sequence parallelism to the new kwargs
transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
# Return empty args and the transformed kwargs
return (), transformed_kwargs
except (ValueError, TypeError):
# If conversion fails, don't transform
return None
# If no args and no kwargs, nothing to transform
return None
# Register the pre-forward hook on the model
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.active = False self.active = False
# Only restore original forward if this is the last active context # Remove all hooks
if ( for handle in self.hook_handles:
SequenceParallelContext._active_contexts handle.remove()
and SequenceParallelContext._active_contexts[-1] == self self.hook_handles = []
):
SequenceParallelContext._active_contexts.pop()
# Restore original forward method
if self._original_module_forward is not None:
nn.Module.forward = self._original_module_forward
self._original_module_forward = None
# Remove any hooks we added
for hook in self._hooks:
hook.remove()
self._hooks = []
def _make_sequence_parallel_forward(self, original_forward):
"""Create a wrapped forward method that applies sequence parallelism."""
@functools.wraps(original_forward)
def sequence_parallel_forward(module_self, *args, **kwargs):
# Skip sequence parallelism if not active or degree is 1
if not self.active or self.sequence_parallel_degree <= 1:
return original_forward(module_self, *args, **kwargs)
# Convert args to kwargs if needed
if args:
# Try to map positional args to kwargs based on the forward method signature
import inspect
try:
signature = inspect.signature(original_forward)
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
for i, arg in enumerate(args):
if i < len(param_names):
kwargs[param_names[i]] = arg
else:
# If we can't map all args, fall back to original forward
return original_forward(module_self, *args, **kwargs)
except (ValueError, TypeError):
# If we can't get the signature, just use the original forward
return original_forward(module_self, *args, **kwargs)
# Apply sequence parallelism to the inputs
kwargs = self.apply_sequence_parallelism(kwargs)
# Call the original forward with modified kwargs
return original_forward(module_self, **kwargs)
return sequence_parallel_forward
def apply_sequence_parallelism( def apply_sequence_parallelism(
self, batch: Dict[str, torch.Tensor] self, batch: dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Apply sequence parallelism slicing to a batch. Apply sequence parallelism slicing to a batch.
@@ -136,198 +123,45 @@ class SequenceParallelContext:
if self.sequence_parallel_degree <= 1 or not self.active: if self.sequence_parallel_degree <= 1 or not self.active:
return batch return batch
# Make a copy of the batch to avoid modifying the original
new_batch = dict(batch)
# Get total sequence length from input_ids or inputs_embeds
if "input_ids" in new_batch:
total_seq_len = new_batch["input_ids"].size(1)
elif "inputs_embeds" in new_batch:
total_seq_len = new_batch["inputs_embeds"].size(1)
else:
# If we can't determine sequence length, return the batch as is
return new_batch
# Update ring attention params if needed # Update ring attention params if needed
if new_batch.get("position_ids") is not None: if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=new_batch["position_ids"]) update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing # Slice batch for sequence parallel processing
for key in new_batch: total_seq_len = batch["input_ids"].size(1)
for key in batch:
if ( if (
key in new_batch key in batch
and isinstance(new_batch[key], torch.Tensor) and isinstance(batch[key], torch.Tensor)
and new_batch[key].dim() > 1 and batch[key].dim() > 1
and new_batch[key].size(1) == total_seq_len and batch[key].size(1) == total_seq_len
): ):
if self.ring_attn_func in [ if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING, RingAttnFunc.BATCH_RING,
]: ]:
new_batch[key] = ( # Split in sequential fashion and grab this rank's chunk
new_batch[key] batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank] .chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous() .contiguous()
) )
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = new_batch[key].chunk(2 * self.local_world_size, dim=1) chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern # Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [ selected_chunks = [
chunks[self.local_rank], chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1], chunks[2 * self.local_world_size - self.local_rank - 1],
] ]
new_batch[key] = torch.cat(selected_chunks, dim=1).contiguous() batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE: elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack # Split into striped data and stack
tensor = torch.stack( tensor = torch.stack(
new_batch[key].split(self.local_world_size, dim=1), batch[key].split(self.local_world_size, dim=1),
dim=1, dim=1,
).transpose(1, 2) ).transpose(1, 2)
new_batch[key] = tensor[:, self.local_rank].contiguous() batch[key] = tensor[:, self.local_rank].contiguous()
return new_batch return batch
@contextlib.contextmanager
def sequence_parallel(
sequence_parallel_degree: int = 1,
process_group: Optional[dist.ProcessGroup] = None,
ring_attn_func: Optional[RingAttnFunc] = None,
buffers: Optional[List[torch.Tensor]] = None,
buffer_seq_dims: Optional[List[int]] = None,
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
):
"""
Context manager for sequence parallelism.
This context manager will apply sequence parallelism to model inputs
for all forward passes within its scope.
Args:
sequence_parallel_degree: The degree of sequence parallelism.
process_group: The process group to use for communication. Default is the world group.
ring_attn_func: The ring attention function to use.
buffers: Optional list of buffers to shard (e.g., input tensors, position embeddings).
buffer_seq_dims: Sequence dimensions for each buffer to shard.
no_restore_buffers: Optional set of buffers that don't need to be restored.
Yields:
The sequence parallel context.
"""
buffers = [] if buffers is None else buffers
buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
if len(buffers) != len(buffer_seq_dims):
raise ValueError(
"`buffer_seq_dims` must have the same number of elements as `buffers`."
)
# Save original buffer states
original_buffers = []
for buffer in buffers:
if buffer in no_restore_buffers:
original_buffers.append(None)
else:
original_buffers.append(buffer.clone())
# Create context
context = SequenceParallelContext(
sequence_parallel_degree=sequence_parallel_degree,
process_group=process_group,
ring_attn_func=ring_attn_func,
)
# Apply sequence parallelism to buffers if provided
if buffers and buffer_seq_dims:
for i, (buffer, dim) in enumerate(zip(buffers, buffer_seq_dims)):
if context.sequence_parallel_degree > 1:
# Get local shard
sharded_tensor = context.apply_sequence_parallelism(
{"tensor": buffer.unsqueeze(0)}
)["tensor"].squeeze(0)
# Resize and copy in-place
buffer.resize_(sharded_tensor.shape)
buffer.copy_(sharded_tensor)
try:
# Enter the context
with context:
yield context
finally:
# Restore original buffer states
for buffer, original in zip(buffers, original_buffers):
if original is not None:
buffer.resize_(original.shape)
buffer.copy_(original)
def enable_sequence_parallel_for_module(
module: nn.Module,
sequence_parallel_degree: int = 1,
process_group: Optional[dist.ProcessGroup] = None,
ring_attn_func: Optional[RingAttnFunc] = None,
):
"""
Enable sequence parallelism for a specific module.
This function wraps the module's forward method to automatically apply
sequence parallelism without using a context manager.
Args:
module: The module to enable sequence parallelism for.
sequence_parallel_degree: The degree of sequence parallelism.
process_group: The process group to use for communication.
ring_attn_func: The ring attention function to use.
Returns:
The module with sequence parallelism enabled.
"""
# Create a context for this module
context = SequenceParallelContext(
sequence_parallel_degree=sequence_parallel_degree,
process_group=process_group,
ring_attn_func=ring_attn_func,
)
# Save the original forward method
original_forward = module.forward
@functools.wraps(original_forward)
def sequence_parallel_forward(*args, **kwargs):
# Activate the context
context.active = True
# Convert args to kwargs if needed
if args:
import inspect
try:
signature = inspect.signature(original_forward)
param_names = list(signature.parameters.keys())
for i, arg in enumerate(args):
if i < len(param_names):
kwargs[param_names[i]] = arg
else:
return original_forward(*args, **kwargs)
except (ValueError, TypeError):
return original_forward(*args, **kwargs)
# Apply sequence parallelism to inputs
kwargs = context.apply_sequence_parallelism(kwargs)
# Call original forward with modified inputs
result = original_forward(**kwargs)
# Deactivate the context
context.active = False
return result
# Replace the forward method
module.forward = sequence_parallel_forward
return module

View File

@@ -197,6 +197,7 @@ def execute_training(
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
elif cfg.sequence_parallel_degree > 1: elif cfg.sequence_parallel_degree > 1:
with SequenceParallelContext( with SequenceParallelContext(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree, sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func, ring_attn_func=cfg.ring_attn_func,
): ):

View File

@@ -18,6 +18,7 @@ from pydantic import (
) )
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
DatasetConfig, DatasetConfig,
DPODataset, DPODataset,
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
return data return data
@field_validator("sequence_parallel_degree", mode="after") @model_validator(mode="after")
@classmethod def check_sequence_parallel_degree(self):
def check_sequence_parallel_degree(cls, value, info): if not self.sequence_parallel_degree:
if not value: self.sequence_parallel_degree = 1
value = 1 elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
if value > 1:
if not info.data.get("flash_attention"):
raise ValueError( raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1" "flash_attention: true must be set with sequence_parallel_degree > 1"
) )
if ( if self.sample_packing and not self.micro_batch_size:
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
raise ValueError( raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled" "micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement" "due to a `ring-flash-attn` requirement"
@@ -1182,44 +1178,44 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks # TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled # / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank. # according to the proportion of non-padding tokens per rank.
LOG.warning( if is_main_process():
"Sequence parallelism (SP) is enabled with " LOG.warning(
f"sequence_parallel_degree={value}. Please note that logged losses may " "Sequence parallelism (SP) is enabled with "
"differ slightly to the non-SP losses due to transformers Trainer " f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"implementation details. Please see " "Please note that logged losses may differ slightly to the non-SP "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " "losses due to transformers Trainer implementation details. "
"for more details." "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
) "for more details."
)
return value return self
@field_validator("ring_attn_func", mode="after") @model_validator(mode="after")
@classmethod def validate_ring_attn_func(self):
def check_ring_attn_func(cls, value, info): if self.sequence_parallel_degree == 1:
if not info.data.get("sequence_parallel_degree", 1) > 1: return self
return value
# Your validation logic for ring_attn_func
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None: if self.ring_attn_func is not None:
# Set the ring attention function if passed in config
valid_funcs = list(RingAttnFunc) valid_funcs = list(RingAttnFunc)
if value in valid_funcs: if self.ring_attn_func in valid_funcs:
value = RingAttnFunc(value) self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else: else:
raise ValueError( raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}" f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
) )
else: else:
# Default ring attention function selection # Default ring attention function selection
sample_packing = info.data.get("sample_packing") sample_packing = getattr(self, "sample_packing", False)
value = ( self.ring_attn_func = (
RingAttnFunc.VARLEN_LLAMA3 RingAttnFunc.VARLEN_LLAMA3
if sample_packing if sample_packing
else RingAttnFunc.BATCH_RING else RingAttnFunc.BATCH_RING
) )
return value return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod