reorg
This commit is contained in:
@@ -1,12 +1,21 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
"""
|
||||||
|
Module for Axolotl trainer sequence parallelism mixin and training context manager
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from torch import nn
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
|
RingAttnFunc,
|
||||||
|
get_ring_attn_group,
|
||||||
|
update_ring_attn_params,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -87,3 +96,106 @@ class SequenceParallelMixin:
|
|||||||
return self._create_sequence_parallel_sampler(
|
return self._create_sequence_parallel_sampler(
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelContext:
|
||||||
|
"""
|
||||||
|
Context manager for sequence parallelism operations.
|
||||||
|
|
||||||
|
This class provides a context that will automatically apply sequence parallelism
|
||||||
|
during model forward passes using a pre-forward hook.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
sequence_parallel_degree: int,
|
||||||
|
ring_attn_func: RingAttnFunc,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
|
self.ring_attn_func = ring_attn_func
|
||||||
|
self.process_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
# Initialize sequence parallel group details
|
||||||
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
|
|
||||||
|
# Will store hook handles for removal
|
||||||
|
self.hook_handle: RemovableHandle | None = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
# Define a forward pre-hook to apply sequence parallelism with kwargs support
|
||||||
|
def sequence_parallel_pre_hook(
|
||||||
|
module, args, kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
# Apply sequence parallelism to kwargs
|
||||||
|
kwargs = self.apply_sequence_parallelism(kwargs)
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
# Register the pre-forward hook on the model
|
||||||
|
self.hook_handle = self.model.register_forward_pre_hook(
|
||||||
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
# Remove the forward pre-hook
|
||||||
|
self.hook_handle.remove()
|
||||||
|
self.hook_handle = None
|
||||||
|
|
||||||
|
def apply_sequence_parallelism(
|
||||||
|
self, batch: dict[str, torch.Tensor]
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sliced batch dictionary.
|
||||||
|
"""
|
||||||
|
# Update ring attention params if needed
|
||||||
|
if batch.get("position_ids") is not None:
|
||||||
|
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||||
|
|
||||||
|
# Slice batch for sequence parallel processing
|
||||||
|
total_seq_len = batch["input_ids"].size(1)
|
||||||
|
for key in batch:
|
||||||
|
if (
|
||||||
|
key in batch
|
||||||
|
and isinstance(batch[key], torch.Tensor)
|
||||||
|
and batch[key].dim() > 1
|
||||||
|
and batch[key].size(1) == total_seq_len
|
||||||
|
):
|
||||||
|
|
||||||
|
if self.ring_attn_func in [
|
||||||
|
RingAttnFunc.VARLEN_LLAMA3,
|
||||||
|
RingAttnFunc.BATCH_RING,
|
||||||
|
]:
|
||||||
|
# Split in sequential fashion and grab this rank's chunk
|
||||||
|
batch[key] = (
|
||||||
|
batch[key]
|
||||||
|
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||||
|
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||||
|
|
||||||
|
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||||
|
selected_chunks = [
|
||||||
|
chunks[self.local_rank],
|
||||||
|
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||||
|
]
|
||||||
|
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||||
|
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||||
|
# Split into striped data and stack
|
||||||
|
tensor = torch.stack(
|
||||||
|
batch[key].split(self.local_world_size, dim=1),
|
||||||
|
dim=1,
|
||||||
|
).transpose(1, 2)
|
||||||
|
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|||||||
@@ -1,119 +0,0 @@
|
|||||||
"""Module for definition of sequence parallel context manager"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import nn
|
|
||||||
from torch.utils.hooks import RemovableHandle
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import (
|
|
||||||
RingAttnFunc,
|
|
||||||
get_ring_attn_group,
|
|
||||||
update_ring_attn_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelContext:
|
|
||||||
"""
|
|
||||||
Context manager for sequence parallelism operations.
|
|
||||||
|
|
||||||
This class provides a context that will automatically apply sequence parallelism
|
|
||||||
during model forward passes using a pre-forward hook.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
sequence_parallel_degree: int,
|
|
||||||
ring_attn_func: RingAttnFunc,
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
|
||||||
self.ring_attn_func = ring_attn_func
|
|
||||||
self.process_group = get_ring_attn_group()
|
|
||||||
|
|
||||||
# Initialize sequence parallel group details
|
|
||||||
self.local_rank = dist.get_rank(self.process_group)
|
|
||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
|
||||||
|
|
||||||
# Will store hook handles for removal
|
|
||||||
self.hook_handle: RemovableHandle | None = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
# Define a forward pre-hook to apply sequence parallelism with kwargs support
|
|
||||||
def sequence_parallel_pre_hook(
|
|
||||||
module, args, kwargs
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
# Apply sequence parallelism to kwargs
|
|
||||||
kwargs = self.apply_sequence_parallelism(kwargs)
|
|
||||||
return args, kwargs
|
|
||||||
|
|
||||||
# Register the pre-forward hook on the model
|
|
||||||
self.hook_handle = self.model.register_forward_pre_hook(
|
|
||||||
sequence_parallel_pre_hook, with_kwargs=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
# Remove the forward pre-hook
|
|
||||||
self.hook_handle.remove()
|
|
||||||
self.hook_handle = None
|
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
|
||||||
self, batch: dict[str, torch.Tensor]
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Apply sequence parallelism slicing to a batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sliced batch dictionary.
|
|
||||||
"""
|
|
||||||
# Update ring attention params if needed
|
|
||||||
if batch.get("position_ids") is not None:
|
|
||||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
|
||||||
|
|
||||||
# Slice batch for sequence parallel processing
|
|
||||||
total_seq_len = batch["input_ids"].size(1)
|
|
||||||
for key in batch:
|
|
||||||
if (
|
|
||||||
key in batch
|
|
||||||
and isinstance(batch[key], torch.Tensor)
|
|
||||||
and batch[key].dim() > 1
|
|
||||||
and batch[key].size(1) == total_seq_len
|
|
||||||
):
|
|
||||||
|
|
||||||
if self.ring_attn_func in [
|
|
||||||
RingAttnFunc.VARLEN_LLAMA3,
|
|
||||||
RingAttnFunc.BATCH_RING,
|
|
||||||
]:
|
|
||||||
# Split in sequential fashion and grab this rank's chunk
|
|
||||||
batch[key] = (
|
|
||||||
batch[key]
|
|
||||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
|
||||||
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
|
||||||
|
|
||||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
|
||||||
selected_chunks = [
|
|
||||||
chunks[self.local_rank],
|
|
||||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
|
||||||
]
|
|
||||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
|
||||||
# Split into striped data and stack
|
|
||||||
tensor = torch.stack(
|
|
||||||
batch[key].split(self.local_world_size, dim=1),
|
|
||||||
dim=1,
|
|
||||||
).transpose(1, 2)
|
|
||||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
|
||||||
|
|
||||||
return batch
|
|
||||||
@@ -26,7 +26,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
|||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.core.trainers.sp import SequenceParallelContext
|
from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
|
|||||||
Reference in New Issue
Block a user