Compare commits
74 Commits
muon-valid
...
82d04ea060
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82d04ea060 | ||
|
|
0ef1f011fe | ||
|
|
c0a1d205c7 | ||
|
|
d0e739da24 | ||
|
|
3f6be519d5 | ||
|
|
adcbc7459b | ||
|
|
470ba65c44 | ||
|
|
8e1adc154d | ||
|
|
e5b36900e4 | ||
|
|
9f6c89b12b | ||
|
|
b0871c8d3b | ||
|
|
d3ea379a23 | ||
|
|
0ebab63309 | ||
|
|
e98581f6f5 | ||
|
|
b832b11c8f | ||
|
|
b692d394b1 | ||
|
|
2319e5276d | ||
|
|
9a43a0925d | ||
|
|
10de67e8ea | ||
|
|
fa7355404c | ||
|
|
907424a2e8 | ||
|
|
3f4fd3c1eb | ||
|
|
48c3c47071 | ||
|
|
3ed9c117fb | ||
|
|
84960003ed | ||
|
|
93a268e43d | ||
|
|
065f6d477e | ||
|
|
96ad741cd5 | ||
|
|
ba88bc7840 | ||
|
|
b31796a681 | ||
|
|
5ca57cb55a | ||
|
|
0149de7fb0 | ||
|
|
8c34c65181 | ||
|
|
555aa5772a | ||
|
|
e8b2789086 | ||
|
|
85752cdfc9 | ||
|
|
f2f23c8041 | ||
|
|
8b3eec7f6e | ||
|
|
bb9bea3110 | ||
|
|
0dd18a3681 | ||
|
|
152e988d3c | ||
|
|
27532825a9 | ||
|
|
06f83a54a5 | ||
|
|
d7b133dc1f | ||
|
|
f3bec17917 | ||
|
|
b7deb5241c | ||
|
|
cee310dcfa | ||
|
|
d1be6e228d | ||
|
|
5f9f77f384 | ||
|
|
b2a34380b3 | ||
|
|
80bfc50d1f | ||
|
|
a5360c172c | ||
|
|
013a9b73fc | ||
|
|
aad62428e0 | ||
|
|
a6f2c5d583 | ||
|
|
dbcd11e533 | ||
|
|
c06a6be915 | ||
|
|
d3a0cb5edb | ||
|
|
8b47e456b0 | ||
|
|
2319ac729c | ||
|
|
f99cae0e7b | ||
|
|
888cd9407f | ||
|
|
bd62d6e10a | ||
|
|
5eae134110 | ||
|
|
b7d27bdfa4 | ||
|
|
da97a21bdc | ||
|
|
e0d4b88598 | ||
|
|
fac059a209 | ||
|
|
9c9ac1cf0b | ||
|
|
2346f21b2b | ||
|
|
0b47281f51 | ||
|
|
543daaf46f | ||
|
|
bcd9ad44e0 | ||
|
|
61ad375bf4 |
@@ -79,6 +79,7 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
|
|||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
|
FlexBatchSamplerDataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
@@ -816,6 +817,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
FlexBatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
@@ -827,7 +829,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.flex_attention is True:
|
||||||
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
146
src/axolotl/monkeypatch/flex_attn.py
Normal file
146
src/axolotl/monkeypatch/flex_attn.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
|
||||||
|
"""
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
|
from torch.nn.attention.flex_attention import (
|
||||||
|
create_block_mask as create_block_causal_mask_flex,
|
||||||
|
)
|
||||||
|
|
||||||
|
_MaskType = Union[torch.Tensor, BlockMask]
|
||||||
|
|
||||||
|
|
||||||
|
def create_block_causal_mask(
|
||||||
|
seq_lens: list[torch.Tensor], max_seq_len: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Given a batch tensor of seq lens defining the lengths of samples in each pack,
|
||||||
|
Construct a 2D block causal mask for each pack in the batch. For example, if
|
||||||
|
a single sample's seq_lens is [3, 2, 1], the mask would be::
|
||||||
|
|
||||||
|
mask = [
|
||||||
|
[1, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 1, 0, 0],
|
||||||
|
[0, 0, 0, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0, 1],
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
||||||
|
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
||||||
|
across packs.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len).
|
||||||
|
"""
|
||||||
|
batch_block_attn_masks = []
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
for sample_idx in range(batch_size):
|
||||||
|
block_attn_masks = [
|
||||||
|
torch.trilu( # torch.tril(
|
||||||
|
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
|
||||||
|
)
|
||||||
|
for seq_len in seq_lens[sample_idx]
|
||||||
|
]
|
||||||
|
|
||||||
|
"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
|
||||||
|
block_attn_masks.append(
|
||||||
|
torch.tril(
|
||||||
|
torch.ones(
|
||||||
|
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)"""
|
||||||
|
|
||||||
|
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
|
||||||
|
|
||||||
|
return torch.stack(batch_block_attn_masks)[:, None, :, :]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_document_ids_from_seq_lens(
|
||||||
|
seq_lens: list[torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
|
||||||
|
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
||||||
|
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
||||||
|
across packs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Document IDs of shape (batch_size, max_seq_len).
|
||||||
|
"""
|
||||||
|
batch_size = len(seq_lens)
|
||||||
|
batch_document_ids = []
|
||||||
|
for sample_idx in range(batch_size):
|
||||||
|
# We assume seq lens sum to max seq lens, so document_ids should be of
|
||||||
|
# shape (max_seq_len, )
|
||||||
|
document_ids = torch.cat(
|
||||||
|
[
|
||||||
|
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
|
||||||
|
for i, seq_len in enumerate(seq_lens[sample_idx])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
batch_document_ids.append(document_ids)
|
||||||
|
batch_document_ids = torch.stack(batch_document_ids)
|
||||||
|
return batch_document_ids
|
||||||
|
|
||||||
|
|
||||||
|
def packed_block_causal_mask(
|
||||||
|
seq_lens: list[torch.Tensor], totalseqlens: list[int]
|
||||||
|
) -> _MaskType:
|
||||||
|
"""
|
||||||
|
Create a block causal document mask for a batch of packed sequences. If
|
||||||
|
flex attention is supported by the current hardware, block causal logic and
|
||||||
|
passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
||||||
|
The resultant BlockMask is a compressed representation of the full block causal
|
||||||
|
mask. If on an older version, a standard 2D block causal mask is created and returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
|
||||||
|
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
|
||||||
|
across packs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
document_ids = _get_document_ids_from_seq_lens(seq_lens)
|
||||||
|
batch_size , max_seq_len = document_ids.shape
|
||||||
|
document_ids = document_ids.to("cuda")
|
||||||
|
totalseqlens = totalseqlens.to("cuda")
|
||||||
|
|
||||||
|
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
||||||
|
# that determines which elements of QK^T should be included in the attention
|
||||||
|
# computation prior to the softmax. For sample packing, we need both the
|
||||||
|
# logic for both causal mask and document mask. See PyTorch's official
|
||||||
|
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
|
||||||
|
def mask_mod(b, h, q_idx, kv_idx):
|
||||||
|
"""
|
||||||
|
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||||
|
and a block diagonal document mask.
|
||||||
|
|
||||||
|
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||||
|
for an illustration.
|
||||||
|
"""
|
||||||
|
causal_mask = q_idx >= kv_idx
|
||||||
|
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
|
||||||
|
finite_mask = q_idx < totalseqlens[b]
|
||||||
|
return causal_mask & document_mask & finite_mask
|
||||||
|
|
||||||
|
return create_block_causal_mask_flex(
|
||||||
|
mask_mod,
|
||||||
|
batch_size,
|
||||||
|
None,
|
||||||
|
max_seq_len,
|
||||||
|
max_seq_len,
|
||||||
|
device="cuda",
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
@@ -95,6 +95,103 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def get_packed_mask_from_pos_ids(position_ids):
|
||||||
|
if len(position_ids.shape) == 1:
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
device = position_ids.device
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, row in enumerate(position_ids):
|
||||||
|
# Count the number of consecutive zeros from the right side
|
||||||
|
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||||
|
|
||||||
|
# Adjust the row to exclude padding
|
||||||
|
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||||
|
|
||||||
|
# Find where the position resets to 0 (indicating a new sequence)
|
||||||
|
seq_starts = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([True], dtype=torch.bool, device=device),
|
||||||
|
adjusted_row[1:] == 0,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Get the indices where the sequence starts
|
||||||
|
start_indices = torch.cat(
|
||||||
|
[
|
||||||
|
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
||||||
|
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Calculate the sequence lengths
|
||||||
|
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||||
|
# Append the padding length to the sequence lengths
|
||||||
|
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
|
||||||
|
for i, seq_len in enumerate(seq_lengths):
|
||||||
|
start_id = start_indices[i]
|
||||||
|
doc_mask[start_id : start_id + seq_len] = (
|
||||||
|
(i+1) * doc_mask[start_id : start_id + seq_len]
|
||||||
|
)
|
||||||
|
if padding_length:
|
||||||
|
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
||||||
|
|
||||||
|
results.append(doc_mask)
|
||||||
|
|
||||||
|
return torch.stack(results)
|
||||||
|
|
||||||
|
|
||||||
|
def get_seqlens_from_pos_ids(position_ids):
|
||||||
|
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
||||||
|
if len(position_ids.shape) == 1:
|
||||||
|
position_ids = position_ids.unsqueeze(0)
|
||||||
|
max_seq_len = position_ids.shape[1]
|
||||||
|
|
||||||
|
device = position_ids.device
|
||||||
|
results = []
|
||||||
|
totalseqlens = []
|
||||||
|
|
||||||
|
for row in position_ids:
|
||||||
|
# Count the number of consecutive zeros from the right side
|
||||||
|
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
||||||
|
|
||||||
|
# Adjust the row to exclude padding
|
||||||
|
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
||||||
|
|
||||||
|
# Find where the position resets to 0 (indicating a new sequence)
|
||||||
|
seq_starts = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([True], dtype=torch.bool, device=device),
|
||||||
|
adjusted_row[1:] == 0,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Get the indices where the sequence starts
|
||||||
|
start_indices = torch.cat(
|
||||||
|
[
|
||||||
|
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
||||||
|
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Calculate the sequence lengths
|
||||||
|
seq_lengths = start_indices[1:] - start_indices[:-1]
|
||||||
|
# Append the padding length to the sequence lengths
|
||||||
|
if padding_length:
|
||||||
|
seq_lengths = torch.cat(
|
||||||
|
[
|
||||||
|
seq_lengths,
|
||||||
|
torch.tensor(
|
||||||
|
[len(row) - torch.sum(seq_lengths)],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(seq_lengths)
|
||||||
|
totalseqlens.append(len(adjusted_row))
|
||||||
|
|
||||||
|
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
@@ -176,7 +273,10 @@ def mask_2d_to_4d(
|
|||||||
when they attend to each other within that sequence.
|
when they attend to each other within that sequence.
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
bsz, src_len = mask.size()
|
|
||||||
|
if len(mask.size()) == 4:
|
||||||
|
return mask
|
||||||
|
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ shared axolotl collators for multipack, mamba, multimodal
|
|||||||
from .batching import ( # noqa: F401
|
from .batching import ( # noqa: F401
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
|
FlexBatchSamplerDataCollatorForSeq2Seq,
|
||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,12 +3,21 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.flex_attn import (
|
||||||
|
create_block_causal_mask,
|
||||||
|
packed_block_causal_mask,
|
||||||
|
)
|
||||||
|
from axolotl.monkeypatch.utils import (
|
||||||
|
get_packed_mask_from_pos_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -151,6 +160,42 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
"""
|
||||||
|
Collator for multipack specific to Flex Attention using the BatchSampler
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
if not isinstance(features[0], list):
|
||||||
|
features = [features]
|
||||||
|
out_features = [{} for _ in features]
|
||||||
|
for i, features_ in enumerate(features):
|
||||||
|
for feature in features_[0].keys():
|
||||||
|
if feature == "length":
|
||||||
|
continue
|
||||||
|
elif feature == "attention_mask":
|
||||||
|
"""arrays = [
|
||||||
|
i * np.array(item[feature])
|
||||||
|
for i, item in enumerate(features_)
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
out_features[i][feature] = np.concatenate(arrays)"""
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
|
]
|
||||||
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
out = super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
# collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
|
||||||
|
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
|
||||||
|
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
|
||||||
|
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -784,6 +784,7 @@ class AxolotlInputConfig(
|
|||||||
xformers_attention: Optional[bool] = None
|
xformers_attention: Optional[bool] = None
|
||||||
sdp_attention: Optional[bool] = None
|
sdp_attention: Optional[bool] = None
|
||||||
s2_attention: Optional[bool] = None
|
s2_attention: Optional[bool] = None
|
||||||
|
flex_attention: Optional[bool] = None
|
||||||
flash_attention: Optional[bool] = None
|
flash_attention: Optional[bool] = None
|
||||||
flash_attn_cross_entropy: Optional[bool] = None
|
flash_attn_cross_entropy: Optional[bool] = None
|
||||||
flash_attn_rms_norm: Optional[bool] = None
|
flash_attn_rms_norm: Optional[bool] = None
|
||||||
@@ -1679,6 +1680,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_flex_torch_version(cls, data):
|
||||||
|
if (data.get("flex_attention") is not None) and (
|
||||||
|
data.get("flex_attention") is True
|
||||||
|
):
|
||||||
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
|
torch_version = env_capabilities.get("torch_version")
|
||||||
|
|
||||||
|
if torch_version is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
|
||||||
|
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||||
|
raise ValueError(
|
||||||
|
"Flex attention is not supported on torch version < 2.5.1"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_auto(cls, data):
|
def check_torch_compile_auto(cls, data):
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and self.cfg.flash_attention
|
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if "auto_map" in self.model_config:
|
if "auto_map" in self.model_config:
|
||||||
@@ -708,7 +708,13 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
if self.cfg.flash_attention:
|
|
||||||
|
if self.cfg.flex_attention:
|
||||||
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"flex_attention"
|
||||||
|
)
|
||||||
|
elif self.cfg.flash_attention:
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
@@ -1100,7 +1106,7 @@ class ModelLoader:
|
|||||||
should_convert = (
|
should_convert = (
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
|
||||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user