Dan Saunders
2025-04-24 00:11:37 +00:00
parent 5816433121
commit d92ac7a41d
3 changed files with 115 additions and 122 deletions


@@ -1,12 +1,21 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
LOG = logging.getLogger(__name__)
@@ -87,3 +96,106 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContext:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handle: RemovableHandle | None = None
def __enter__(self):
# Define a forward pre-hook to apply sequence parallelism with kwargs support
def sequence_parallel_pre_hook(
module, args, kwargs
): # pylint: disable=unused-argument
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(kwargs)
return args, kwargs
# Register the pre-forward hook on the model
self.hook_handle = self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove the forward pre-hook
self.hook_handle.remove()
self.hook_handle = None
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
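For context, a minimal usage sketch of the new context manager as it might sit around a trainer's forward pass. This is an illustration, not code from this commit: it assumes torch.distributed is already initialized, the ring-attention monkeypatch has registered the group returned by get_ring_attn_group(), and `model` and `batch` stand in for a causal LM and a collated batch built elsewhere.

# Hypothetical usage sketch (not part of this diff).
with SequenceParallelContext(
    model=model,
    sequence_parallel_degree=2,  # assumed to match the configured sequence parallel degree
    ring_attn_func=RingAttnFunc.BATCH_RING,
):
    # The registered pre-forward hook slices input_ids, attention_mask, labels,
    # etc. along the sequence dimension before the model sees them.
    outputs = model(**batch)
    loss = outputs.loss
loss.backward()  # backward is unaffected; only the forward kwargs were sharded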

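The three RingAttnFunc branches above are easiest to see on a toy tensor. The following standalone sketch is an illustration rather than part of the commit; it hard-codes a sequence-parallel group of two ranks and replays the same three slicing expressions used in apply_sequence_parallelism, printing each rank's shard.

import torch

local_world_size = 2
x = torch.arange(8).unsqueeze(0)  # shape (batch=1, seq_len=8), positions 0..7

for local_rank in range(local_world_size):
    # VARLEN_LLAMA3 / BATCH_RING: one contiguous chunk per rank.
    ring = x.chunk(local_world_size, dim=1)[local_rank].contiguous()

    # BATCH_ZIGZAG: this rank's chunk plus the mirrored chunk from the far end.
    chunks = x.chunk(2 * local_world_size, dim=1)
    zigzag = torch.cat(
        [chunks[local_rank], chunks[2 * local_world_size - local_rank - 1]], dim=1
    ).contiguous()

    # BATCH_STRIPE: split into stripes of width local_world_size, keep one column.
    striped = torch.stack(x.split(local_world_size, dim=1), dim=1).transpose(1, 2)
    stripe = striped[:, local_rank].contiguous()

    # rank 0 -> ring [0,1,2,3], zigzag [0,1,6,7], stripe [0,2,4,6]
    # rank 1 -> ring [4,5,6,7], zigzag [2,3,4,5], stripe [1,3,5,7]
    print(local_rank, ring.tolist(), zigzag.tolist(), stripe.tolist())

The zigzag and striped layouts matter under a causal mask, where later positions attend to more prior tokens than earlier ones: pairing a front chunk with a back chunk, or interleaving stripes, keeps per-rank attention work roughly balanced.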

@@ -1,119 +0,0 @@
"""Module for definition of sequence parallel context manager"""
import logging
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn.patch import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
logger = logging.getLogger(__name__)
class SequenceParallelContext:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handle: RemovableHandle | None = None
def __enter__(self):
# Define a forward pre-hook to apply sequence parallelism with kwargs support
def sequence_parallel_pre_hook(
module, args, kwargs
): # pylint: disable=unused-argument
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(kwargs)
return args, kwargs
# Register the pre-forward hook on the model
self.hook_handle = self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove the forward pre-hook
self.hook_handle.remove()
self.hook_handle = None
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch


@@ -26,7 +26,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.sp import SequenceParallelContext
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed