Dan Saunders
2025-04-23 23:19:52 +00:00
parent cafda804ec
commit 69aeae80ed
4 changed files with 112 additions and 274 deletions

View File

@@ -371,13 +371,20 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
loss = super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
# This is needed due to details of our sequence parallel implementation; the HF
# trainer averages the loss over the full sequence length despite our splitting
# the data along the sequence dimension.
loss *= self.args.sequence_parallel_degree
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
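A note on the loss rescaling above: with sequence parallelism, each rank only holds 1/sequence_parallel_degree of the tokens, but the HF trainer still divides the summed loss by the full-sequence token count, so the local mean comes out too small by exactly that factor. The following toy arithmetic (illustrative only, not the trainer's actual code path) shows why multiplying by sequence_parallel_degree recovers the true local mean:

import torch

sequence_parallel_degree = 4
full_seq_len = 16
local_seq_len = full_seq_len // sequence_parallel_degree  # this rank's shard

per_token_loss = torch.full((local_seq_len,), 2.0)         # toy per-token losses on this rank
trainer_style_mean = per_token_loss.sum() / full_seq_len   # 0.5: divided by the *full* length
corrected = trainer_style_mean * sequence_parallel_degree  # 2.0: the true local mean
assert torch.isclose(corrected, per_token_loss.mean())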

View File

@@ -1,11 +1,11 @@
import contextlib
import functools
"""Module for definition of sequence parallel context manager"""
import inspect
import logging
from typing import Dict, List, Optional, Set
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import nn
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn.patch import (
@@ -22,17 +22,16 @@ class SequenceParallelContext:
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes.
during model forward passes using pre-forward hooks.
"""
# Keep track of active contexts to support nested contexts
_active_contexts = []
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
@@ -42,9 +41,8 @@ class SequenceParallelContext:
self.local_world_size = 1
self.active = False
# Will store original methods for restoration
self._original_module_forward = None
self._hooks: List[RemovableHandle] = []
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
if self.sequence_parallel_degree > 1:
if self.process_group is None:
@@ -55,75 +53,64 @@ class SequenceParallelContext:
def __enter__(self):
self.active = True
SequenceParallelContext._active_contexts.append(self)
# Store the original forward method
if self._original_module_forward is None:
self._original_module_forward = nn.Module.forward
# Define a pre-forward hook to apply sequence parallelism with kwargs support
def sequence_parallel_pre_hook(module, args, kwargs):
if not self.active or self.sequence_parallel_degree <= 1:
return None
# Replace nn.Module.forward with our sequence parallel version
nn.Module.forward = self._make_sequence_parallel_forward(nn.Module.forward)
# Apply sequence parallelism to kwargs
if kwargs:
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
return args, transformed_kwargs
# If no kwargs but we have args, try to convert them to kwargs
if args and not kwargs:
try:
signature = inspect.signature(module.forward)
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
# Create kwargs from args
new_kwargs = {}
for i, arg in enumerate(args):
if i < len(param_names):
new_kwargs[param_names[i]] = arg
else:
# If we can't map all args, don't transform
return None
# Apply sequence parallelism to the new kwargs
transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
# Return empty args and the transformed kwargs
return (), transformed_kwargs
except (ValueError, TypeError):
# If conversion fails, don't transform
return None
# If no args and no kwargs, nothing to transform
return None
# Register the pre-forward hook on the model
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
return self
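For reference, the __enter__ above replaces the previous nn.Module.forward monkeypatch with a kwargs-aware pre-forward hook. A minimal standalone sketch of that hook mechanism follows; the toy module and the hard-coded slice are stand-ins for the real model and apply_sequence_parallelism, not repo code:

import torch
from torch import nn

class ToyModel(nn.Module):
    def forward(self, input_ids=None):
        return input_ids

def pre_hook(module, args, kwargs):
    # Returning (args, kwargs) replaces the inputs that forward() will see.
    if "input_ids" in kwargs:
        kwargs = dict(kwargs)
        kwargs["input_ids"] = kwargs["input_ids"][:, :2]  # stand-in for taking this rank's shard
    return args, kwargs

model = ToyModel()
handle = model.register_forward_pre_hook(pre_hook, with_kwargs=True)
print(model(input_ids=torch.arange(8).reshape(1, 8)).shape)  # torch.Size([1, 2])
handle.remove()  # mirrors the handle cleanup done in __exit__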
def __exit__(self, exc_type, exc_val, exc_tb):
self.active = False
# Only restore original forward if this is the last active context
if (
SequenceParallelContext._active_contexts
and SequenceParallelContext._active_contexts[-1] == self
):
SequenceParallelContext._active_contexts.pop()
# Restore original forward method
if self._original_module_forward is not None:
nn.Module.forward = self._original_module_forward
self._original_module_forward = None
# Remove any hooks we added
for hook in self._hooks:
hook.remove()
self._hooks = []
def _make_sequence_parallel_forward(self, original_forward):
"""Create a wrapped forward method that applies sequence parallelism."""
@functools.wraps(original_forward)
def sequence_parallel_forward(module_self, *args, **kwargs):
# Skip sequence parallelism if not active or degree is 1
if not self.active or self.sequence_parallel_degree <= 1:
return original_forward(module_self, *args, **kwargs)
# Convert args to kwargs if needed
if args:
# Try to map positional args to kwargs based on the forward method signature
import inspect
try:
signature = inspect.signature(original_forward)
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
for i, arg in enumerate(args):
if i < len(param_names):
kwargs[param_names[i]] = arg
else:
# If we can't map all args, fall back to original forward
return original_forward(module_self, *args, **kwargs)
except (ValueError, TypeError):
# If we can't get the signature, just use the original forward
return original_forward(module_self, *args, **kwargs)
# Apply sequence parallelism to the inputs
kwargs = self.apply_sequence_parallelism(kwargs)
# Call the original forward with modified kwargs
return original_forward(module_self, **kwargs)
return sequence_parallel_forward
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def apply_sequence_parallelism(
self, batch: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
@@ -136,198 +123,45 @@ class SequenceParallelContext:
if self.sequence_parallel_degree <= 1 or not self.active:
return batch
# Make a copy of the batch to avoid modifying the original
new_batch = dict(batch)
# Get total sequence length from input_ids or inputs_embeds
if "input_ids" in new_batch:
total_seq_len = new_batch["input_ids"].size(1)
elif "inputs_embeds" in new_batch:
total_seq_len = new_batch["inputs_embeds"].size(1)
else:
# If we can't determine sequence length, return the batch as is
return new_batch
# Update ring attention params if needed
if new_batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=new_batch["position_ids"])
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
for key in new_batch:
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in new_batch
and isinstance(new_batch[key], torch.Tensor)
and new_batch[key].dim() > 1
and new_batch[key].size(1) == total_seq_len
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
new_batch[key] = (
new_batch[key]
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = new_batch[key].chunk(2 * self.local_world_size, dim=1)
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
new_batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
new_batch[key].split(self.local_world_size, dim=1),
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
new_batch[key] = tensor[:, self.local_rank].contiguous()
batch[key] = tensor[:, self.local_rank].contiguous()
return new_batch
@contextlib.contextmanager
def sequence_parallel(
sequence_parallel_degree: int = 1,
process_group: Optional[dist.ProcessGroup] = None,
ring_attn_func: Optional[RingAttnFunc] = None,
buffers: Optional[List[torch.Tensor]] = None,
buffer_seq_dims: Optional[List[int]] = None,
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
):
"""
Context manager for sequence parallelism.
This context manager will apply sequence parallelism to model inputs
for all forward passes within its scope.
Args:
sequence_parallel_degree: The degree of sequence parallelism.
process_group: The process group to use for communication. Default is the world group.
ring_attn_func: The ring attention function to use.
buffers: Optional list of buffers to shard (e.g., input tensors, position embeddings).
buffer_seq_dims: Sequence dimensions for each buffer to shard.
no_restore_buffers: Optional set of buffers that don't need to be restored.
Yields:
The sequence parallel context.
"""
buffers = [] if buffers is None else buffers
buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
if len(buffers) != len(buffer_seq_dims):
raise ValueError(
"`buffer_seq_dims` must have the same number of elements as `buffers`."
)
# Save original buffer states
original_buffers = []
for buffer in buffers:
if buffer in no_restore_buffers:
original_buffers.append(None)
else:
original_buffers.append(buffer.clone())
# Create context
context = SequenceParallelContext(
sequence_parallel_degree=sequence_parallel_degree,
process_group=process_group,
ring_attn_func=ring_attn_func,
)
# Apply sequence parallelism to buffers if provided
if buffers and buffer_seq_dims:
for i, (buffer, dim) in enumerate(zip(buffers, buffer_seq_dims)):
if context.sequence_parallel_degree > 1:
# Get local shard
sharded_tensor = context.apply_sequence_parallelism(
{"tensor": buffer.unsqueeze(0)}
)["tensor"].squeeze(0)
# Resize and copy in-place
buffer.resize_(sharded_tensor.shape)
buffer.copy_(sharded_tensor)
try:
# Enter the context
with context:
yield context
finally:
# Restore original buffer states
for buffer, original in zip(buffers, original_buffers):
if original is not None:
buffer.resize_(original.shape)
buffer.copy_(original)
def enable_sequence_parallel_for_module(
module: nn.Module,
sequence_parallel_degree: int = 1,
process_group: Optional[dist.ProcessGroup] = None,
ring_attn_func: Optional[RingAttnFunc] = None,
):
"""
Enable sequence parallelism for a specific module.
This function wraps the module's forward method to automatically apply
sequence parallelism without using a context manager.
Args:
module: The module to enable sequence parallelism for.
sequence_parallel_degree: The degree of sequence parallelism.
process_group: The process group to use for communication.
ring_attn_func: The ring attention function to use.
Returns:
The module with sequence parallelism enabled.
"""
# Create a context for this module
context = SequenceParallelContext(
sequence_parallel_degree=sequence_parallel_degree,
process_group=process_group,
ring_attn_func=ring_attn_func,
)
# Save the original forward method
original_forward = module.forward
@functools.wraps(original_forward)
def sequence_parallel_forward(*args, **kwargs):
# Activate the context
context.active = True
# Convert args to kwargs if needed
if args:
import inspect
try:
signature = inspect.signature(original_forward)
param_names = list(signature.parameters.keys())
for i, arg in enumerate(args):
if i < len(param_names):
kwargs[param_names[i]] = arg
else:
return original_forward(*args, **kwargs)
except (ValueError, TypeError):
return original_forward(*args, **kwargs)
# Apply sequence parallelism to inputs
kwargs = context.apply_sequence_parallelism(kwargs)
# Call original forward with modified inputs
result = original_forward(**kwargs)
# Deactivate the context
context.active = False
return result
# Replace the forward method
module.forward = sequence_parallel_forward
return module
return batch
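To make the three slicing branches above concrete, here is how an 8-token sequence is distributed across local_world_size = 2 ranks under each ring attention function; this is a standalone toy illustration (values are token positions), not part of the patch:

import torch

world_size = 2
seq = torch.arange(8).unsqueeze(0)  # shape (1, 8); entries are token positions

for rank in range(world_size):
    ring = seq.chunk(world_size, dim=1)[rank]  # sequential chunk (BATCH_RING / VARLEN_LLAMA3)
    chunks = seq.chunk(2 * world_size, dim=1)
    zigzag = torch.cat(  # BATCH_ZIGZAG: this rank's chunk plus its mirror-image chunk
        [chunks[rank], chunks[2 * world_size - rank - 1]], dim=1
    )
    stripe = (  # BATCH_STRIPE: every world_size-th token, offset by rank
        torch.stack(seq.split(world_size, dim=1), dim=1).transpose(1, 2)[:, rank]
    )
    print(rank, ring.tolist(), zigzag.tolist(), stripe.tolist())

# rank 0: ring [[0, 1, 2, 3]]  zigzag [[0, 1, 6, 7]]  stripe [[0, 2, 4, 6]]
# rank 1: ring [[4, 5, 6, 7]]  zigzag [[2, 3, 4, 5]]  stripe [[1, 3, 5, 7]]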

View File

@@ -197,6 +197,7 @@ def execute_training(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
elif cfg.sequence_parallel_degree > 1:
with SequenceParallelContext(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
):

View File

@@ -18,6 +18,7 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import is_main_process
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
return data
@field_validator("sequence_parallel_degree", mode="after")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
if self.sample_packing and not self.micro_batch_size == 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"
@@ -1182,44 +1178,44 @@ class AxolotlInputConfig(
# TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
if is_main_process():
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return value
return self
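The hunk above converts a per-field field_validator into a Pydantic v2 model_validator(mode="after"), so related fields are read directly off self instead of through info.data. A minimal self-contained sketch of that pattern, with the field set trimmed down (this is not the full AxolotlInputConfig):

from pydantic import BaseModel, model_validator

class SPConfig(BaseModel):
    sequence_parallel_degree: int | None = None
    flash_attention: bool = False

    @model_validator(mode="after")
    def check_sequence_parallel_degree(self):
        # Runs after the whole model is built, so sibling fields are attributes.
        if not self.sequence_parallel_degree:
            self.sequence_parallel_degree = 1
        elif self.sequence_parallel_degree > 1 and not self.flash_attention:
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )
        return self

print(SPConfig().sequence_parallel_degree)  # 1 (defaulted by the validator)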
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
@model_validator(mode="after")
def validate_ring_attn_func(self):
if self.sequence_parallel_degree == 1:
return self
# Validate the configured ring attention function, or pick a default below
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
sample_packing = getattr(self, "sample_packing", False)
self.ring_attn_func = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
return self
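The membership check above (self.ring_attn_func in valid_funcs) accepts a raw config string only when the Enum is string-valued; here is a small stand-in sketch of that behavior (the real RingAttnFunc lives in axolotl.monkeypatch.attention.ring_attn.patch and may define more members):

from enum import Enum

class RingAttnFunc(str, Enum):  # stand-in; assumes a str-valued Enum
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"

value = "batch_ring"
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:        # str-Enum members compare equal to their raw string values
    value = RingAttnFunc(value)
print(value)  # RingAttnFunc.BATCH_RING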
@model_validator(mode="before")
@classmethod