batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements * update * adding all batch ring-flash-attn methods via single adapter * removing pad_to_sequence_len=False for now * fix * updating docs to include batch SP * review comments * fixes for batch API funcs, simplify * fixes * fix * updates * add batch_zigzag smoke test
This commit is contained in:
@@ -776,6 +776,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||
self.cfg.sequence_parallel_degree
|
||||
)
|
||||
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -933,6 +934,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
kwargs["return_tensors"] = "pt"
|
||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||
kwargs["ring_attn_func"] = training_args.ring_attn_func
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
|
||||
@@ -9,6 +9,8 @@ from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
ring_attn_func: Optional[RingAttnFunc] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The ring-flash-attn function to use in sequence parallelism"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
|
||||
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
12
src/axolotl/monkeypatch/attention/ring_attn/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Init for ring attention monkeypatch module"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .patch import (
|
||||
RingAttnFunc,
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
192
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
HuggingFace flash attention adapter for basic ring attention (batch API).
|
||||
|
||||
Inspired by
|
||||
https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
|
||||
Our implementation closely follows the structure of that module, but we've minified it
|
||||
somewhat to support only the latest versions of transformers.
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access,cyclic-import
|
||||
|
||||
import os
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
import transformers.modeling_flash_attention_utils
|
||||
from ring_flash_attn import (
|
||||
ring_flash_attn_func,
|
||||
stripe_flash_attn_func,
|
||||
zigzag_ring_flash_attn_func,
|
||||
)
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size,
|
||||
is_flash_attn_greater_or_equal,
|
||||
)
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
RING_ATTN_FUNC_MAPPING = {
|
||||
RingAttnFunc.BATCH_RING: ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
|
||||
RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
|
||||
}
|
||||
|
||||
|
||||
def create_flash_attn_forward(
|
||||
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
|
||||
) -> Callable:
|
||||
"""
|
||||
Create a ring flash attention forward function compatible with HuggingFace's
|
||||
interface.
|
||||
|
||||
Args:
|
||||
process_group: A PyTorch distributed process group.
|
||||
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
|
||||
attention with.
|
||||
|
||||
Returns:
|
||||
A function that implements the ring flash attention forward pass with the
|
||||
signature expected by HuggingFace Transformers.
|
||||
"""
|
||||
|
||||
# transformers 4.48+
|
||||
# pylint: disable=unused-argument
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
query_length: int,
|
||||
is_causal: bool,
|
||||
dropout: float = 0.0,
|
||||
position_ids: torch.Tensor | None = None,
|
||||
softmax_scale: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: float | None = None,
|
||||
deterministic: bool = None,
|
||||
cu_seq_lens_q: torch.LongTensor | None = None,
|
||||
cu_seq_lens_k: torch.LongTensor | None = None,
|
||||
max_length_q: int | None = None,
|
||||
max_length_k: int | None = None,
|
||||
target_dtype: torch.dtype | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Ring Flash Attention.
|
||||
|
||||
Args:
|
||||
query_states: Tensor containing the query vectors.
|
||||
key_states: Tensor containing the key vectors.
|
||||
value_states: Tensor containing the value vectors.
|
||||
attention_mask: Not used in this implementation.
|
||||
query_length: Integer representing the length of the query sequence.
|
||||
is_causal: Boolean indicating whether to apply a causal mask to the attention.
|
||||
dropout: Float representing the dropout probability. Default is 0.0.
|
||||
position_ids: Not used in this implementation.
|
||||
softmax_scale: Optional float value for the softmax scaling factor. Default is None.
|
||||
sliding_window: Optional integer defining the size of the sliding attention window.
|
||||
Default is None.
|
||||
use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
|
||||
Default is False.
|
||||
softcap: Not used in this implementation.
|
||||
deterministic: Optional boolean to enforce deterministic computation. Default is None.
|
||||
cu_seq_lens_q: Not used in this implementation.
|
||||
cu_seq_lens_k: Not used in this implementation.
|
||||
max_length_q: Not used in this implementation.
|
||||
max_length_k: Not used in this implementation.
|
||||
target_dtype: Not used in this implementation.
|
||||
**kwargs: Additional keyword arguments. Not used in this implementation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output of the attention mechanism, with shape
|
||||
`[batch_size, query_length, num_heads, head_dim]`.
|
||||
"""
|
||||
if not use_top_left_mask:
|
||||
causal = is_causal
|
||||
else:
|
||||
causal = is_causal and query_length != 1
|
||||
|
||||
# Handle sliding window
|
||||
use_sliding_windows = (
|
||||
_flash_supports_window_size
|
||||
and sliding_window is not None
|
||||
and key_states.shape[1] > sliding_window
|
||||
)
|
||||
window_size = (
|
||||
(sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
|
||||
)
|
||||
|
||||
# Handle deterministic mode
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
if deterministic is None:
|
||||
deterministic = (
|
||||
os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
)
|
||||
|
||||
# Call ring flash attention function
|
||||
attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=None,
|
||||
deterministic=deterministic,
|
||||
return_attn_probs=False,
|
||||
group=process_group,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
return _flash_attention_forward
|
||||
|
||||
|
||||
def substitute_hf_flash_attn(
|
||||
process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
|
||||
):
|
||||
"""
|
||||
Substitute HuggingFace's flash attention implementation with ring-based implementation.
|
||||
|
||||
Args:
|
||||
process_group: PyTorch distributed process group for communication.
|
||||
ring_attn_func: Function from `ring_flash_attention` to replace HF flash
|
||||
attention with.
|
||||
"""
|
||||
try:
|
||||
# Substitute flash attention
|
||||
old_flash_attention_forward = (
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward
|
||||
)
|
||||
new_flash_attention_forward = create_flash_attn_forward(
|
||||
process_group=process_group, ring_attn_func=ring_attn_func
|
||||
)
|
||||
|
||||
if check_params(old_flash_attention_forward, new_flash_attention_forward):
|
||||
transformers.modeling_flash_attention_utils._flash_attention_forward = (
|
||||
new_flash_attention_forward
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The signature of the new flash attention forward function does not match the old one."
|
||||
)
|
||||
except Exception as exception:
|
||||
raise ValueError(
|
||||
f"The current transformer version {transformers.__version__} is not supported. "
|
||||
"Please use pip install -U transformers to upgrade to the latest version. "
|
||||
"If the code failed with the latest version, "
|
||||
f"please file an issue."
|
||||
) from exception
|
||||
|
||||
# Register with ALL_ATTENTION_FUNCTIONS if available
|
||||
if ALL_ATTENTION_FUNCTIONS is not None:
|
||||
from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
|
||||
@@ -6,6 +6,8 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
@@ -16,6 +18,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
RING_ATTN_GROUP = None
|
||||
|
||||
|
||||
@@ -40,7 +43,22 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
BATCH_ZIGZAG = "batch_zigzag"
|
||||
BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
@@ -48,6 +66,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
`batch` function.
|
||||
"""
|
||||
if get_ring_attn_group() is not None:
|
||||
LOG.info("Ring attention already registered, exiting early...")
|
||||
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Detailed logging of group formation
|
||||
rank = dist.get_rank()
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if heads_k_stride is None:
|
||||
heads_k_stride = 1
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
|
||||
)
|
||||
elif ring_attn_func in [
|
||||
RingAttnFunc.BATCH_RING,
|
||||
RingAttnFunc.BATCH_ZIGZAG,
|
||||
RingAttnFunc.BATCH_STRIPE,
|
||||
]:
|
||||
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
|
||||
substitute_hf_flash_attn,
|
||||
)
|
||||
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||
)
|
||||
substitute_hf_flash_attn(
|
||||
process_group=get_ring_attn_group(),
|
||||
ring_attn_func=ring_attn_func,
|
||||
)
|
||||
|
||||
|
||||
def update_ring_attn_params(batch: dict[str, torch.Tensor]):
|
||||
def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
"""
|
||||
Calculate the cumulative sequence lengths for the current forward pass and pass the
|
||||
value to the substituted `ring_flash_attn`.
|
||||
|
||||
Args:
|
||||
batch: A dictionary with a batch of data. May or may not contain `position_ids`
|
||||
data; if not, we compute it.
|
||||
position_ids: Optional tensor of position IDs (for sample packed data).
|
||||
"""
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
position_ids = batch.get("position_ids")
|
||||
if position_ids is None:
|
||||
seq_len = input_ids.shape[1]
|
||||
position_ids = torch.arange(
|
||||
0, seq_len, dtype=torch.long, device=input_ids.device
|
||||
).unsqueeze(0)
|
||||
|
||||
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
|
||||
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||
@@ -4,7 +4,7 @@ includes logic for handling sequence parallelism collation.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -13,6 +13,7 @@ from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,14 +54,15 @@ class DataCollatorForSeq2Seq:
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
model: Any | None = None
|
||||
padding: bool | str | PaddingStrategy = True
|
||||
max_length: int | None = None
|
||||
pad_to_multiple_of: int | None = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_degree: int = 1
|
||||
ring_attn_func: RingAttnFunc | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_degree > 1:
|
||||
@@ -157,19 +159,41 @@ class DataCollatorForSeq2Seq:
|
||||
Sliced batch dictionary.
|
||||
"""
|
||||
# Get local (start, end) for sequence parallelism slicing
|
||||
total_seq_len = batch["input_ids"].shape[1]
|
||||
slice_size = total_seq_len // self.local_world_size
|
||||
start = self.local_rank * slice_size
|
||||
end = start + slice_size
|
||||
total_seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# Update params for ring attention calculation
|
||||
update_ring_attn_params(batch=batch)
|
||||
# Update params for varlen ring attention calculation
|
||||
if batch.get("position_ids") is not None:
|
||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
||||
|
||||
# Slice batch for sequence parallel processing
|
||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||
for key in keys_to_slice:
|
||||
if key in batch:
|
||||
batch[key] = batch[key][:, start:end]
|
||||
for key in batch:
|
||||
if batch[key].size(1) == total_seq_len:
|
||||
if self.ring_attn_func in [
|
||||
RingAttnFunc.VARLEN_LLAMA3,
|
||||
RingAttnFunc.BATCH_RING,
|
||||
]:
|
||||
batch[key] = (
|
||||
batch[key]
|
||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
||||
.contiguous()
|
||||
)
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
||||
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
||||
|
||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
||||
selected_chunks = [
|
||||
chunks[self.local_rank],
|
||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
||||
]
|
||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
||||
# TODO(djsaunde): This doesn't seem to work as expected
|
||||
# Split into striped data and stack
|
||||
tensor = torch.stack(
|
||||
batch[key].split(self.local_world_size, dim=1),
|
||||
dim=1,
|
||||
).transpose(1, 2)
|
||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@@ -655,6 +655,7 @@ class ModelLoader:
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
ring_attn_func=self.cfg.ring_attn_func,
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
@@ -1119,7 +1120,7 @@ class ModelLoader:
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def ajust_model_config(self) -> None:
|
||||
def adjust_model_config(self) -> None:
|
||||
if (
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "max_position_embeddings")
|
||||
@@ -1279,7 +1280,7 @@ class ModelLoader:
|
||||
else:
|
||||
self.model.tie_weights()
|
||||
|
||||
self.ajust_model_config()
|
||||
self.adjust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
|
||||
@@ -259,6 +259,7 @@ class AxolotlInputConfig(
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
heads_k_stride: int | None = None
|
||||
ring_attn_func: str | None = None
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1147,7 +1148,7 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@field_validator("sequence_parallel_degree", mode="before")
|
||||
@field_validator("sequence_parallel_degree", mode="after")
|
||||
@classmethod
|
||||
def check_sequence_parallel_degree(cls, value, info):
|
||||
if not value:
|
||||
@@ -1159,9 +1160,12 @@ class AxolotlInputConfig(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if not info.data["micro_batch_size"] == 1:
|
||||
if (
|
||||
info.data.get("sample_packing")
|
||||
and not info.data["micro_batch_size"] == 1
|
||||
):
|
||||
raise ValueError(
|
||||
"micro_batch_size must be set to 1 "
|
||||
"micro_batch_size must be set to 1 when sample_packing is enabled"
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
@@ -1188,6 +1192,34 @@ class AxolotlInputConfig(
|
||||
|
||||
return value
|
||||
|
||||
@field_validator("ring_attn_func", mode="after")
|
||||
@classmethod
|
||||
def check_ring_attn_func(cls, value, info):
|
||||
if not info.data.get("sequence_parallel_degree", 1) > 1:
|
||||
return value
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||
|
||||
if value is not None:
|
||||
# Set the ring attention function if passed in config
|
||||
valid_funcs = list(RingAttnFunc)
|
||||
if value in valid_funcs:
|
||||
value = RingAttnFunc(value)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"ring_attn_func: {value} must be one of {valid_funcs}"
|
||||
)
|
||||
else:
|
||||
# Default ring attention function selection
|
||||
sample_packing = info.data.get("sample_packing")
|
||||
value = (
|
||||
RingAttnFunc.VARLEN_LLAMA3
|
||||
if sample_packing
|
||||
else RingAttnFunc.BATCH_RING
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user