updates
This commit is contained in:
@@ -371,13 +371,20 @@ class AxolotlTrainer(
|
|||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
loss = super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# This is needed due to details of our sequence parallel implementation; the HF
|
||||||
|
# trainer averages the loss over the full sequence length depite our splitting
|
||||||
|
# the data along the sequence dimension.
|
||||||
|
loss *= self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||||
concatenated_batch = {}
|
concatenated_batch = {}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import contextlib
|
"""Module for definition of sequence parallel context manager"""
|
||||||
import functools
|
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
from torch import nn
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import (
|
from axolotl.monkeypatch.attention.ring_attn.patch import (
|
||||||
@@ -22,17 +22,16 @@ class SequenceParallelContext:
|
|||||||
Context manager for sequence parallelism operations.
|
Context manager for sequence parallelism operations.
|
||||||
|
|
||||||
This class provides a context that will automatically apply sequence parallelism
|
This class provides a context that will automatically apply sequence parallelism
|
||||||
during model forward passes.
|
during model forward passes using pre-forward hooks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Keep track of active contexts to support nested contexts
|
|
||||||
_active_contexts = []
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model: nn.Module,
|
||||||
sequence_parallel_degree: int,
|
sequence_parallel_degree: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
):
|
):
|
||||||
|
self.model = model
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.process_group = get_ring_attn_group()
|
self.process_group = get_ring_attn_group()
|
||||||
@@ -42,9 +41,8 @@ class SequenceParallelContext:
|
|||||||
self.local_world_size = 1
|
self.local_world_size = 1
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
# Will store original methods for restoration
|
# Will store hook handles for removal
|
||||||
self._original_module_forward = None
|
self.hook_handles: list[RemovableHandle] = []
|
||||||
self._hooks: List[RemovableHandle] = []
|
|
||||||
|
|
||||||
if self.sequence_parallel_degree > 1:
|
if self.sequence_parallel_degree > 1:
|
||||||
if self.process_group is None:
|
if self.process_group is None:
|
||||||
@@ -55,75 +53,64 @@ class SequenceParallelContext:
|
|||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.active = True
|
self.active = True
|
||||||
SequenceParallelContext._active_contexts.append(self)
|
|
||||||
|
|
||||||
# Store the original forward method
|
# Define a pre-forward hook to apply sequence parallelism with kwargs support
|
||||||
if self._original_module_forward is None:
|
def sequence_parallel_pre_hook(module, args, kwargs):
|
||||||
self._original_module_forward = nn.Module.forward
|
if not self.active or self.sequence_parallel_degree <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
# Replace nn.Module.forward with our sequence parallel version
|
# Apply sequence parallelism to kwargs
|
||||||
nn.Module.forward = self._make_sequence_parallel_forward(nn.Module.forward)
|
if kwargs:
|
||||||
|
transformed_kwargs = self.apply_sequence_parallelism(kwargs)
|
||||||
|
return args, transformed_kwargs
|
||||||
|
|
||||||
|
# If no kwargs but we have args, try to convert them to kwargs
|
||||||
|
if args and not kwargs:
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(module.forward)
|
||||||
|
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
|
||||||
|
|
||||||
|
# Create kwargs from args
|
||||||
|
new_kwargs = {}
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if i < len(param_names):
|
||||||
|
new_kwargs[param_names[i]] = arg
|
||||||
|
else:
|
||||||
|
# If we can't map all args, don't transform
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply sequence parallelism to the new kwargs
|
||||||
|
transformed_kwargs = self.apply_sequence_parallelism(new_kwargs)
|
||||||
|
|
||||||
|
# Return empty args and the transformed kwargs
|
||||||
|
return (), transformed_kwargs
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# If conversion fails, don't transform
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If no args and no kwargs, nothing to transform
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register the pre-forward hook on the model
|
||||||
|
self.hook_handles.append(
|
||||||
|
self.model.register_forward_pre_hook(
|
||||||
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
# Only restore original forward if this is the last active context
|
# Remove all hooks
|
||||||
if (
|
for handle in self.hook_handles:
|
||||||
SequenceParallelContext._active_contexts
|
handle.remove()
|
||||||
and SequenceParallelContext._active_contexts[-1] == self
|
self.hook_handles = []
|
||||||
):
|
|
||||||
SequenceParallelContext._active_contexts.pop()
|
|
||||||
|
|
||||||
# Restore original forward method
|
|
||||||
if self._original_module_forward is not None:
|
|
||||||
nn.Module.forward = self._original_module_forward
|
|
||||||
self._original_module_forward = None
|
|
||||||
|
|
||||||
# Remove any hooks we added
|
|
||||||
for hook in self._hooks:
|
|
||||||
hook.remove()
|
|
||||||
self._hooks = []
|
|
||||||
|
|
||||||
def _make_sequence_parallel_forward(self, original_forward):
|
|
||||||
"""Create a wrapped forward method that applies sequence parallelism."""
|
|
||||||
|
|
||||||
@functools.wraps(original_forward)
|
|
||||||
def sequence_parallel_forward(module_self, *args, **kwargs):
|
|
||||||
# Skip sequence parallelism if not active or degree is 1
|
|
||||||
if not self.active or self.sequence_parallel_degree <= 1:
|
|
||||||
return original_forward(module_self, *args, **kwargs)
|
|
||||||
|
|
||||||
# Convert args to kwargs if needed
|
|
||||||
if args:
|
|
||||||
# Try to map positional args to kwargs based on the forward method signature
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
try:
|
|
||||||
signature = inspect.signature(original_forward)
|
|
||||||
param_names = list(signature.parameters.keys())[1:] # Skip 'self'
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if i < len(param_names):
|
|
||||||
kwargs[param_names[i]] = arg
|
|
||||||
else:
|
|
||||||
# If we can't map all args, fall back to original forward
|
|
||||||
return original_forward(module_self, *args, **kwargs)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# If we can't get the signature, just use the original forward
|
|
||||||
return original_forward(module_self, *args, **kwargs)
|
|
||||||
|
|
||||||
# Apply sequence parallelism to the inputs
|
|
||||||
kwargs = self.apply_sequence_parallelism(kwargs)
|
|
||||||
|
|
||||||
# Call the original forward with modified kwargs
|
|
||||||
return original_forward(module_self, **kwargs)
|
|
||||||
|
|
||||||
return sequence_parallel_forward
|
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
def apply_sequence_parallelism(
|
||||||
self, batch: Dict[str, torch.Tensor]
|
self, batch: dict[str, torch.Tensor]
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Apply sequence parallelism slicing to a batch.
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
@@ -136,198 +123,45 @@ class SequenceParallelContext:
|
|||||||
if self.sequence_parallel_degree <= 1 or not self.active:
|
if self.sequence_parallel_degree <= 1 or not self.active:
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
# Make a copy of the batch to avoid modifying the original
|
|
||||||
new_batch = dict(batch)
|
|
||||||
|
|
||||||
# Get total sequence length from input_ids or inputs_embeds
|
|
||||||
if "input_ids" in new_batch:
|
|
||||||
total_seq_len = new_batch["input_ids"].size(1)
|
|
||||||
elif "inputs_embeds" in new_batch:
|
|
||||||
total_seq_len = new_batch["inputs_embeds"].size(1)
|
|
||||||
else:
|
|
||||||
# If we can't determine sequence length, return the batch as is
|
|
||||||
return new_batch
|
|
||||||
|
|
||||||
# Update ring attention params if needed
|
# Update ring attention params if needed
|
||||||
if new_batch.get("position_ids") is not None:
|
if batch.get("position_ids") is not None:
|
||||||
update_ring_attn_params(position_ids=new_batch["position_ids"])
|
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||||
|
|
||||||
# Slice batch for sequence parallel processing
|
# Slice batch for sequence parallel processing
|
||||||
for key in new_batch:
|
total_seq_len = batch["input_ids"].size(1)
|
||||||
|
for key in batch:
|
||||||
if (
|
if (
|
||||||
key in new_batch
|
key in batch
|
||||||
and isinstance(new_batch[key], torch.Tensor)
|
and isinstance(batch[key], torch.Tensor)
|
||||||
and new_batch[key].dim() > 1
|
and batch[key].dim() > 1
|
||||||
and new_batch[key].size(1) == total_seq_len
|
and batch[key].size(1) == total_seq_len
|
||||||
):
|
):
|
||||||
|
|
||||||
if self.ring_attn_func in [
|
if self.ring_attn_func in [
|
||||||
RingAttnFunc.VARLEN_LLAMA3,
|
RingAttnFunc.VARLEN_LLAMA3,
|
||||||
RingAttnFunc.BATCH_RING,
|
RingAttnFunc.BATCH_RING,
|
||||||
]:
|
]:
|
||||||
new_batch[key] = (
|
# Split in sequential fashion and grab this rank's chunk
|
||||||
new_batch[key]
|
batch[key] = (
|
||||||
|
batch[key]
|
||||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||||
.contiguous()
|
.contiguous()
|
||||||
)
|
)
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||||
chunks = new_batch[key].chunk(2 * self.local_world_size, dim=1)
|
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||||
|
|
||||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||||
selected_chunks = [
|
selected_chunks = [
|
||||||
chunks[self.local_rank],
|
chunks[self.local_rank],
|
||||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||||
]
|
]
|
||||||
new_batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||||
# Split into striped data and stack
|
# Split into striped data and stack
|
||||||
tensor = torch.stack(
|
tensor = torch.stack(
|
||||||
new_batch[key].split(self.local_world_size, dim=1),
|
batch[key].split(self.local_world_size, dim=1),
|
||||||
dim=1,
|
dim=1,
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
new_batch[key] = tensor[:, self.local_rank].contiguous()
|
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||||
|
|
||||||
return new_batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def sequence_parallel(
|
|
||||||
sequence_parallel_degree: int = 1,
|
|
||||||
process_group: Optional[dist.ProcessGroup] = None,
|
|
||||||
ring_attn_func: Optional[RingAttnFunc] = None,
|
|
||||||
buffers: Optional[List[torch.Tensor]] = None,
|
|
||||||
buffer_seq_dims: Optional[List[int]] = None,
|
|
||||||
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Context manager for sequence parallelism.
|
|
||||||
|
|
||||||
This context manager will apply sequence parallelism to model inputs
|
|
||||||
for all forward passes within its scope.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence_parallel_degree: The degree of sequence parallelism.
|
|
||||||
process_group: The process group to use for communication. Default is the world group.
|
|
||||||
ring_attn_func: The ring attention function to use.
|
|
||||||
buffers: Optional list of buffers to shard (e.g., input tensors, position embeddings).
|
|
||||||
buffer_seq_dims: Sequence dimensions for each buffer to shard.
|
|
||||||
no_restore_buffers: Optional set of buffers that don't need to be restored.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
The sequence parallel context.
|
|
||||||
"""
|
|
||||||
buffers = [] if buffers is None else buffers
|
|
||||||
buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
|
|
||||||
no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
|
|
||||||
|
|
||||||
if len(buffers) != len(buffer_seq_dims):
|
|
||||||
raise ValueError(
|
|
||||||
"`buffer_seq_dims` must have the same number of elements as `buffers`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save original buffer states
|
|
||||||
original_buffers = []
|
|
||||||
for buffer in buffers:
|
|
||||||
if buffer in no_restore_buffers:
|
|
||||||
original_buffers.append(None)
|
|
||||||
else:
|
|
||||||
original_buffers.append(buffer.clone())
|
|
||||||
|
|
||||||
# Create context
|
|
||||||
context = SequenceParallelContext(
|
|
||||||
sequence_parallel_degree=sequence_parallel_degree,
|
|
||||||
process_group=process_group,
|
|
||||||
ring_attn_func=ring_attn_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply sequence parallelism to buffers if provided
|
|
||||||
if buffers and buffer_seq_dims:
|
|
||||||
for i, (buffer, dim) in enumerate(zip(buffers, buffer_seq_dims)):
|
|
||||||
if context.sequence_parallel_degree > 1:
|
|
||||||
# Get local shard
|
|
||||||
sharded_tensor = context.apply_sequence_parallelism(
|
|
||||||
{"tensor": buffer.unsqueeze(0)}
|
|
||||||
)["tensor"].squeeze(0)
|
|
||||||
|
|
||||||
# Resize and copy in-place
|
|
||||||
buffer.resize_(sharded_tensor.shape)
|
|
||||||
buffer.copy_(sharded_tensor)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Enter the context
|
|
||||||
with context:
|
|
||||||
yield context
|
|
||||||
finally:
|
|
||||||
# Restore original buffer states
|
|
||||||
for buffer, original in zip(buffers, original_buffers):
|
|
||||||
if original is not None:
|
|
||||||
buffer.resize_(original.shape)
|
|
||||||
buffer.copy_(original)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_sequence_parallel_for_module(
|
|
||||||
module: nn.Module,
|
|
||||||
sequence_parallel_degree: int = 1,
|
|
||||||
process_group: Optional[dist.ProcessGroup] = None,
|
|
||||||
ring_attn_func: Optional[RingAttnFunc] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Enable sequence parallelism for a specific module.
|
|
||||||
|
|
||||||
This function wraps the module's forward method to automatically apply
|
|
||||||
sequence parallelism without using a context manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module: The module to enable sequence parallelism for.
|
|
||||||
sequence_parallel_degree: The degree of sequence parallelism.
|
|
||||||
process_group: The process group to use for communication.
|
|
||||||
ring_attn_func: The ring attention function to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The module with sequence parallelism enabled.
|
|
||||||
"""
|
|
||||||
# Create a context for this module
|
|
||||||
context = SequenceParallelContext(
|
|
||||||
sequence_parallel_degree=sequence_parallel_degree,
|
|
||||||
process_group=process_group,
|
|
||||||
ring_attn_func=ring_attn_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save the original forward method
|
|
||||||
original_forward = module.forward
|
|
||||||
|
|
||||||
@functools.wraps(original_forward)
|
|
||||||
def sequence_parallel_forward(*args, **kwargs):
|
|
||||||
# Activate the context
|
|
||||||
context.active = True
|
|
||||||
|
|
||||||
# Convert args to kwargs if needed
|
|
||||||
if args:
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
try:
|
|
||||||
signature = inspect.signature(original_forward)
|
|
||||||
param_names = list(signature.parameters.keys())
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if i < len(param_names):
|
|
||||||
kwargs[param_names[i]] = arg
|
|
||||||
else:
|
|
||||||
return original_forward(*args, **kwargs)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return original_forward(*args, **kwargs)
|
|
||||||
|
|
||||||
# Apply sequence parallelism to inputs
|
|
||||||
kwargs = context.apply_sequence_parallelism(kwargs)
|
|
||||||
|
|
||||||
# Call original forward with modified inputs
|
|
||||||
result = original_forward(**kwargs)
|
|
||||||
|
|
||||||
# Deactivate the context
|
|
||||||
context.active = False
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Replace the forward method
|
|
||||||
module.forward = sequence_parallel_forward
|
|
||||||
|
|
||||||
return module
|
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ def execute_training(
|
|||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
elif cfg.sequence_parallel_degree > 1:
|
elif cfg.sequence_parallel_degree > 1:
|
||||||
with SequenceParallelContext(
|
with SequenceParallelContext(
|
||||||
|
model=trainer.model,
|
||||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from pydantic import (
|
|||||||
)
|
)
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
DatasetConfig,
|
DatasetConfig,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@field_validator("sequence_parallel_degree", mode="after")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def check_sequence_parallel_degree(self):
|
||||||
def check_sequence_parallel_degree(cls, value, info):
|
if not self.sequence_parallel_degree:
|
||||||
if not value:
|
self.sequence_parallel_degree = 1
|
||||||
value = 1
|
elif self.sequence_parallel_degree > 1:
|
||||||
|
if not self.flash_attention:
|
||||||
if value > 1:
|
|
||||||
if not info.data.get("flash_attention"):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if self.sample_packing and not self.micro_batch_size:
|
||||||
info.data.get("sample_packing")
|
|
||||||
and not info.data["micro_batch_size"] == 1
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||||
"due to a `ring-flash-attn` requirement"
|
"due to a `ring-flash-attn` requirement"
|
||||||
@@ -1182,44 +1178,44 @@ class AxolotlInputConfig(
|
|||||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||||
# according to the proportion of non-padding tokens per rank.
|
# according to the proportion of non-padding tokens per rank.
|
||||||
LOG.warning(
|
if is_main_process():
|
||||||
"Sequence parallelism (SP) is enabled with "
|
LOG.warning(
|
||||||
f"sequence_parallel_degree={value}. Please note that logged losses may "
|
"Sequence parallelism (SP) is enabled with "
|
||||||
"differ slightly to the non-SP losses due to transformers Trainer "
|
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
||||||
"implementation details. Please see "
|
"Please note that logged losses may differ slightly to the non-SP "
|
||||||
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
"losses due to transformers Trainer implementation details. "
|
||||||
"for more details."
|
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
)
|
"for more details."
|
||||||
|
)
|
||||||
|
|
||||||
return value
|
return self
|
||||||
|
|
||||||
@field_validator("ring_attn_func", mode="after")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def validate_ring_attn_func(self):
|
||||||
def check_ring_attn_func(cls, value, info):
|
if self.sequence_parallel_degree == 1:
|
||||||
if not info.data.get("sequence_parallel_degree", 1) > 1:
|
return self
|
||||||
return value
|
|
||||||
|
|
||||||
|
# Your validation logic for ring_attn_func
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||||
|
|
||||||
if value is not None:
|
if self.ring_attn_func is not None:
|
||||||
# Set the ring attention function if passed in config
|
|
||||||
valid_funcs = list(RingAttnFunc)
|
valid_funcs = list(RingAttnFunc)
|
||||||
if value in valid_funcs:
|
if self.ring_attn_func in valid_funcs:
|
||||||
value = RingAttnFunc(value)
|
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"ring_attn_func: {value} must be one of {valid_funcs}"
|
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Default ring attention function selection
|
# Default ring attention function selection
|
||||||
sample_packing = info.data.get("sample_packing")
|
sample_packing = getattr(self, "sample_packing", False)
|
||||||
value = (
|
self.ring_attn_func = (
|
||||||
RingAttnFunc.VARLEN_LLAMA3
|
RingAttnFunc.VARLEN_LLAMA3
|
||||||
if sample_packing
|
if sample_packing
|
||||||
else RingAttnFunc.BATCH_RING
|
else RingAttnFunc.BATCH_RING
|
||||||
)
|
)
|
||||||
|
|
||||||
return value
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user