updates
@@ -371,13 +371,20 @@ class AxolotlTrainer(
                 num_items_in_batch=num_items_in_batch,
             )
 
-        return super().compute_loss(
+        loss = super().compute_loss(
             model,
             inputs,
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )
+
+        # This is needed due to details of our sequence parallel implementation; the HF
+        # trainer averages the loss over the full sequence length despite our splitting
+        # the data along the sequence dimension.
+        loss *= self.args.sequence_parallel_degree
+
+        return loss
 
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
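Note: the scaling above exists because the Hugging Face Trainer divides the summed loss by a token count taken over the full, unsplit sequence, while under sequence parallelism each rank only computes loss over its own shard. A minimal toy sketch of the effect (hypothetical numbers, not part of this commit; assumes tokens are spread roughly evenly across ranks):

import torch

sp_degree = 2
full_seq_len = 8
shard_len = full_seq_len // sp_degree                # each rank holds 4 tokens
token_losses = torch.arange(1.0, full_seq_len + 1)   # per-token losses, full sequence

# What one rank ends up with: a sum over its 4 tokens divided by all 8 tokens.
rank0_loss = token_losses[:shard_len].sum() / full_seq_len   # 1.25

# Multiplying by sp_degree recovers the per-token average over the rank's own shard.
print(rank0_loss * sp_degree)                                # tensor(2.5000)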
@@ -1,11 +1,11 @@
-import contextlib
-import functools
+"""Module for definition of sequence parallel context manager"""
+
 import inspect
 import logging
 from typing import Dict, List, Optional, Set
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
+from torch import nn
 from torch.utils.hooks import RemovableHandle
 
 from axolotl.monkeypatch.attention.ring_attn.patch import (
@@ -22,17 +22,16 @@ class SequenceParallelContext:
     Context manager for sequence parallelism operations.
 
     This class provides a context that will automatically apply sequence parallelism
-    during model forward passes.
+    during model forward passes using pre-forward hooks.
     """
 
     # Keep track of active contexts to support nested contexts
     _active_contexts = []
 
     def __init__(
         self,
         model: nn.Module,
         sequence_parallel_degree: int,
         ring_attn_func: RingAttnFunc,
     ):
         self.model = model
         self.sequence_parallel_degree = sequence_parallel_degree
         self.ring_attn_func = ring_attn_func
         self.process_group = get_ring_attn_group()
@@ -42,9 +41,8 @@ class SequenceParallelContext:
         self.local_world_size = 1
         self.active = False
 
-        # Will store original methods for restoration
-        self._original_module_forward = None
-        self._hooks: List[RemovableHandle] = []
+        # Will store hook handles for removal
+        self.hook_handles: list[RemovableHandle] = []
 
         if self.sequence_parallel_degree > 1:
             if self.process_group is None:
@@ -55,75 +53,64 @@ class SequenceParallelContext:
 
     def __enter__(self):
         self.active = True
         SequenceParallelContext._active_contexts.append(self)
 
-        # Store the original forward method
-        if self._original_module_forward is None:
-            self._original_module_forward = nn.Module.forward
+        # Define a pre-forward hook to apply sequence parallelism with kwargs support
+        def sequence_parallel_pre_hook(module, args, kwargs):
+            if not self.active or self.sequence_parallel_degree <= 1:
+                return None
 
-            # Replace nn.Module.forward with our sequence parallel version
-            nn.Module.forward = self._make_sequence_parallel_forward(nn.Module.forward)
+            # Apply sequence parallelism to kwargs
+            if kwargs:
+                transformed_kwargs = self.apply_sequence_parallelism(kwargs)
+                return args, transformed_kwargs
+
+            # If no kwargs but we have args, try to convert them to kwargs
+            if args and not kwargs:
+                try:
+                    signature = inspect.signature(module.forward)
+                    param_names = list(signature.parameters.keys())[1:]  # Skip 'self'
+
+                    # Create kwargs from args
+                    new_kwargs = {}
+                    for i, arg in enumerate(args):
+                        if i < len(param_names):
+                            new_kwargs[param_names[i]] = arg
+                        else:
+                            # If we can't map all args, don't transform
+                            return None
+
+                    # Apply sequence parallelism to the new kwargs
+                    transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
+
+                    # Return empty args and the transformed kwargs
+                    return (), transformed_kwargs
+                except (ValueError, TypeError):
+                    # If conversion fails, don't transform
+                    return None
+
+            # If no args and no kwargs, nothing to transform
+            return None
+
+        # Register the pre-forward hook on the model
+        self.hook_handles.append(
+            self.model.register_forward_pre_hook(
+                sequence_parallel_pre_hook, with_kwargs=True
+            )
+        )
 
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.active = False
 
         # Only restore original forward if this is the last active context
         if (
             SequenceParallelContext._active_contexts
             and SequenceParallelContext._active_contexts[-1] == self
         ):
             SequenceParallelContext._active_contexts.pop()
 
-        # Restore original forward method
-        if self._original_module_forward is not None:
-            nn.Module.forward = self._original_module_forward
-            self._original_module_forward = None
-
-        # Remove any hooks we added
-        for hook in self._hooks:
-            hook.remove()
-        self._hooks = []
-
-    def _make_sequence_parallel_forward(self, original_forward):
-        """Create a wrapped forward method that applies sequence parallelism."""
-
-        @functools.wraps(original_forward)
-        def sequence_parallel_forward(module_self, *args, **kwargs):
-            # Skip sequence parallelism if not active or degree is 1
-            if not self.active or self.sequence_parallel_degree <= 1:
-                return original_forward(module_self, *args, **kwargs)
-
-            # Convert args to kwargs if needed
-            if args:
-                # Try to map positional args to kwargs based on the forward method signature
-                import inspect
-
-                try:
-                    signature = inspect.signature(original_forward)
-                    param_names = list(signature.parameters.keys())[1:]  # Skip 'self'
-                    for i, arg in enumerate(args):
-                        if i < len(param_names):
-                            kwargs[param_names[i]] = arg
-                        else:
-                            # If we can't map all args, fall back to original forward
-                            return original_forward(module_self, *args, **kwargs)
-                except (ValueError, TypeError):
-                    # If we can't get the signature, just use the original forward
-                    return original_forward(module_self, *args, **kwargs)
-
-            # Apply sequence parallelism to the inputs
-            kwargs = self.apply_sequence_parallelism(kwargs)
-
-            # Call the original forward with modified kwargs
-            return original_forward(module_self, **kwargs)
-
-        return sequence_parallel_forward
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
 
     def apply_sequence_parallelism(
-        self, batch: Dict[str, torch.Tensor]
-    ) -> Dict[str, torch.Tensor]:
+        self, batch: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
         """
         Apply sequence parallelism slicing to a batch.
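Note: the rewritten __enter__ above relies on PyTorch's kwargs-aware pre-forward hooks instead of monkeypatching nn.Module.forward. A self-contained sketch of that mechanism (toy module and hook, not from this commit; assumes PyTorch >= 2.0, where register_forward_pre_hook accepts with_kwargs=True and the hook may return (args, kwargs) to replace the inputs):

import torch
import torch.nn as nn

def keep_first_half(module, args, kwargs):
    # Mimic the simplest slicing: keep only this rank's half of the sequence.
    if "input_ids" in kwargs:
        seq_len = kwargs["input_ids"].size(1)
        kwargs = {**kwargs, "input_ids": kwargs["input_ids"][:, : seq_len // 2]}
        return args, kwargs
    return None  # leave the inputs untouched

class ToyModel(nn.Module):
    def forward(self, input_ids=None):
        return input_ids.shape

model = ToyModel()
handle = model.register_forward_pre_hook(keep_first_half, with_kwargs=True)
print(model(input_ids=torch.zeros(1, 8, dtype=torch.long)))  # torch.Size([1, 4])
handle.remove()  # mirrors the cleanup of hook_handles in __exit__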
@@ -136,198 +123,45 @@ class SequenceParallelContext:
         if self.sequence_parallel_degree <= 1 or not self.active:
             return batch
 
-        # Make a copy of the batch to avoid modifying the original
-        new_batch = dict(batch)
-
-        # Get total sequence length from input_ids or inputs_embeds
-        if "input_ids" in new_batch:
-            total_seq_len = new_batch["input_ids"].size(1)
-        elif "inputs_embeds" in new_batch:
-            total_seq_len = new_batch["inputs_embeds"].size(1)
-        else:
-            # If we can't determine sequence length, return the batch as is
-            return new_batch
-
         # Update ring attention params if needed
-        if new_batch.get("position_ids") is not None:
-            update_ring_attn_params(position_ids=new_batch["position_ids"])
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(position_ids=batch["position_ids"])
 
         # Slice batch for sequence parallel processing
-        for key in new_batch:
+        total_seq_len = batch["input_ids"].size(1)
+        for key in batch:
             if (
-                key in new_batch
-                and isinstance(new_batch[key], torch.Tensor)
-                and new_batch[key].dim() > 1
-                and new_batch[key].size(1) == total_seq_len
+                key in batch
+                and isinstance(batch[key], torch.Tensor)
+                and batch[key].dim() > 1
+                and batch[key].size(1) == total_seq_len
             ):
 
                 if self.ring_attn_func in [
                     RingAttnFunc.VARLEN_LLAMA3,
                     RingAttnFunc.BATCH_RING,
                 ]:
-                    new_batch[key] = (
-                        new_batch[key]
+                    # Split in sequential fashion and grab this rank's chunk
+                    batch[key] = (
+                        batch[key]
                         .chunk(self.local_world_size, dim=1)[self.local_rank]
                         .contiguous()
                     )
                 elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-                    chunks = new_batch[key].chunk(2 * self.local_world_size, dim=1)
+                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
 
                     # Take rank's chunk and opposing chunk for zigzag pattern
                     selected_chunks = [
                         chunks[self.local_rank],
                         chunks[2 * self.local_world_size - self.local_rank - 1],
                     ]
-                    new_batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
                 elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
                     # Split into striped data and stack
                     tensor = torch.stack(
-                        new_batch[key].split(self.local_world_size, dim=1),
+                        batch[key].split(self.local_world_size, dim=1),
                         dim=1,
                     ).transpose(1, 2)
-                    new_batch[key] = tensor[:, self.local_rank].contiguous()
+                    batch[key] = tensor[:, self.local_rank].contiguous()
 
-        return new_batch
-
-
-@contextlib.contextmanager
-def sequence_parallel(
-    sequence_parallel_degree: int = 1,
-    process_group: Optional[dist.ProcessGroup] = None,
-    ring_attn_func: Optional[RingAttnFunc] = None,
-    buffers: Optional[List[torch.Tensor]] = None,
-    buffer_seq_dims: Optional[List[int]] = None,
-    no_restore_buffers: Optional[Set[torch.Tensor]] = None,
-):
-    """
-    Context manager for sequence parallelism.
-
-    This context manager will apply sequence parallelism to model inputs
-    for all forward passes within its scope.
-
-    Args:
-        sequence_parallel_degree: The degree of sequence parallelism.
-        process_group: The process group to use for communication. Default is the world group.
-        ring_attn_func: The ring attention function to use.
-        buffers: Optional list of buffers to shard (e.g., input tensors, position embeddings).
-        buffer_seq_dims: Sequence dimensions for each buffer to shard.
-        no_restore_buffers: Optional set of buffers that don't need to be restored.
-
-    Yields:
-        The sequence parallel context.
-    """
-    buffers = [] if buffers is None else buffers
-    buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
-    no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
-
-    if len(buffers) != len(buffer_seq_dims):
-        raise ValueError(
-            "`buffer_seq_dims` must have the same number of elements as `buffers`."
-        )
-
-    # Save original buffer states
-    original_buffers = []
-    for buffer in buffers:
-        if buffer in no_restore_buffers:
-            original_buffers.append(None)
-        else:
-            original_buffers.append(buffer.clone())
-
-    # Create context
-    context = SequenceParallelContext(
-        sequence_parallel_degree=sequence_parallel_degree,
-        process_group=process_group,
-        ring_attn_func=ring_attn_func,
-    )
-
-    # Apply sequence parallelism to buffers if provided
-    if buffers and buffer_seq_dims:
-        for i, (buffer, dim) in enumerate(zip(buffers, buffer_seq_dims)):
-            if context.sequence_parallel_degree > 1:
-                # Get local shard
-                sharded_tensor = context.apply_sequence_parallelism(
-                    {"tensor": buffer.unsqueeze(0)}
-                )["tensor"].squeeze(0)
-
-                # Resize and copy in-place
-                buffer.resize_(sharded_tensor.shape)
-                buffer.copy_(sharded_tensor)
-
-    try:
-        # Enter the context
-        with context:
-            yield context
-    finally:
-        # Restore original buffer states
-        for buffer, original in zip(buffers, original_buffers):
-            if original is not None:
-                buffer.resize_(original.shape)
-                buffer.copy_(original)
-
-
-def enable_sequence_parallel_for_module(
-    module: nn.Module,
-    sequence_parallel_degree: int = 1,
-    process_group: Optional[dist.ProcessGroup] = None,
-    ring_attn_func: Optional[RingAttnFunc] = None,
-):
-    """
-    Enable sequence parallelism for a specific module.
-
-    This function wraps the module's forward method to automatically apply
-    sequence parallelism without using a context manager.
-
-    Args:
-        module: The module to enable sequence parallelism for.
-        sequence_parallel_degree: The degree of sequence parallelism.
-        process_group: The process group to use for communication.
-        ring_attn_func: The ring attention function to use.
-
-    Returns:
-        The module with sequence parallelism enabled.
-    """
-    # Create a context for this module
-    context = SequenceParallelContext(
-        sequence_parallel_degree=sequence_parallel_degree,
-        process_group=process_group,
-        ring_attn_func=ring_attn_func,
-    )
-
-    # Save the original forward method
-    original_forward = module.forward
-
-    @functools.wraps(original_forward)
-    def sequence_parallel_forward(*args, **kwargs):
-        # Activate the context
-        context.active = True
-
-        # Convert args to kwargs if needed
-        if args:
-            import inspect
-
-            try:
-                signature = inspect.signature(original_forward)
-                param_names = list(signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i < len(param_names):
-                        kwargs[param_names[i]] = arg
-                    else:
-                        return original_forward(*args, **kwargs)
-            except (ValueError, TypeError):
-                return original_forward(*args, **kwargs)
-
-        # Apply sequence parallelism to inputs
-        kwargs = context.apply_sequence_parallelism(kwargs)
-
-        # Call original forward with modified inputs
-        result = original_forward(**kwargs)
-
-        # Deactivate the context
-        context.active = False
-
-        return result
-
-    # Replace the forward method
-    module.forward = sequence_parallel_forward
-
-    return module
+        return batch
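Note: apply_sequence_parallelism above supports three slicing layouts (sequential chunks, zigzag, and striped). A standalone sketch of the zigzag selection (hypothetical world size and rank, not from this commit), where each rank keeps its own chunk plus the mirrored chunk so that causal-attention work stays roughly balanced across ranks:

import torch

world_size, rank = 2, 0
x = torch.arange(8).unsqueeze(0)                 # one sequence of 8 tokens
chunks = x.chunk(2 * world_size, dim=1)          # 4 chunks of 2 tokens each
selected = [chunks[rank], chunks[2 * world_size - rank - 1]]
print(torch.cat(selected, dim=1))                # tensor([[0, 1, 6, 7]])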
@@ -197,6 +197,7 @@ def execute_training(
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     elif cfg.sequence_parallel_degree > 1:
         with SequenceParallelContext(
+            model=trainer.model,
             sequence_parallel_degree=cfg.sequence_parallel_degree,
             ring_attn_func=cfg.ring_attn_func,
         ):
@@ -18,6 +18,7 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available
 
+from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
     DPODataset,
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
 
         return data
 
-    @field_validator("sequence_parallel_degree", mode="after")
-    @classmethod
-    def check_sequence_parallel_degree(cls, value, info):
-        if not value:
-            value = 1
-
-        if value > 1:
-            if not info.data.get("flash_attention"):
+    @model_validator(mode="after")
+    def check_sequence_parallel_degree(self):
+        if not self.sequence_parallel_degree:
+            self.sequence_parallel_degree = 1
+        elif self.sequence_parallel_degree > 1:
+            if not self.flash_attention:
                 raise ValueError(
                     "flash_attention: true must be set with sequence_parallel_degree > 1"
                 )
 
-            if (
-                info.data.get("sample_packing")
-                and not info.data["micro_batch_size"] == 1
-            ):
+            if self.sample_packing and not self.micro_batch_size:
                 raise ValueError(
                     "micro_batch_size must be set to 1 when sample_packing is enabled"
                     "due to a `ring-flash-attn` requirement"
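Note: the validator above moves from a per-field field_validator to a model_validator(mode="after"), which runs on the fully constructed config and can read sibling fields (flash_attention, sample_packing, micro_batch_size) directly instead of through info.data. A minimal pydantic v2 sketch of the same pattern (standalone example with made-up field set, not the project's actual config class):

from pydantic import BaseModel, model_validator

class SPConfig(BaseModel):
    sequence_parallel_degree: int | None = None
    flash_attention: bool = False

    @model_validator(mode="after")
    def check_sequence_parallel_degree(self):
        if not self.sequence_parallel_degree:
            self.sequence_parallel_degree = 1
        elif self.sequence_parallel_degree > 1 and not self.flash_attention:
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )
        return self

SPConfig(sequence_parallel_degree=2, flash_attention=True)  # passes
SPConfig()                                                  # degree defaults to 1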
@@ -1182,44 +1178,44 @@ class AxolotlInputConfig(
             # TODO: monkeypatch / callback to average losses correctly across SP ranks
             # / fix gradient scaling across SP ranks. Losses, grads should be scaled
             # according to the proportion of non-padding tokens per rank.
-            LOG.warning(
-                "Sequence parallelism (SP) is enabled with "
-                f"sequence_parallel_degree={value}. Please note that logged losses may "
-                "differ slightly to the non-SP losses due to transformers Trainer "
-                "implementation details. Please see "
-                "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                "for more details."
-            )
+            if is_main_process():
+                LOG.warning(
+                    "Sequence parallelism (SP) is enabled with "
+                    f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+                    "Please note that logged losses may differ slightly to the non-SP "
+                    "losses due to transformers Trainer implementation details. "
+                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                    "for more details."
+                )
 
-        return value
+        return self
 
-    @field_validator("ring_attn_func", mode="after")
-    @classmethod
-    def check_ring_attn_func(cls, value, info):
-        if not info.data.get("sequence_parallel_degree", 1) > 1:
-            return value
+    @model_validator(mode="after")
+    def validate_ring_attn_func(self):
+        if self.sequence_parallel_degree == 1:
+            return self
 
-        # Your validation logic for ring_attn_func
         from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
 
-        if value is not None:
+        # Set the ring attention function if passed in config
+        if self.ring_attn_func is not None:
             valid_funcs = list(RingAttnFunc)
-            if value in valid_funcs:
-                value = RingAttnFunc(value)
+            if self.ring_attn_func in valid_funcs:
+                self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
             else:
                 raise ValueError(
-                    f"ring_attn_func: {value} must be one of {valid_funcs}"
+                    f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
                 )
         else:
             # Default ring attention function selection
-            sample_packing = info.data.get("sample_packing")
-            value = (
+            sample_packing = getattr(self, "sample_packing", False)
+            self.ring_attn_func = (
                 RingAttnFunc.VARLEN_LLAMA3
                 if sample_packing
                 else RingAttnFunc.BATCH_RING
             )
 
-        return value
+        return self
 
     @model_validator(mode="before")
     @classmethod