Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
f24efd77a1 lint docs 2025-02-12 10:04:01 -05:00
8 changed files with 7 additions and 330 deletions

View File

@@ -22,4 +22,4 @@ description: Frequently asked questions
**Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.

View File

@@ -79,7 +79,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
@@ -817,7 +816,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
@@ -829,9 +827,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -1,146 +0,0 @@
"""
Taken from https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention_utils.py
"""
from typing import Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
_MaskType = Union[torch.Tensor, BlockMask]
def create_block_causal_mask(
seq_lens: list[torch.Tensor], max_seq_len: int
) -> torch.Tensor:
"""
Given a batch tensor of seq lens defining the lengths of samples in each pack,
Construct a 2D block causal mask for each pack in the batch. For example, if
a single sample's seq_lens is [3, 2, 1], the mask would be::
mask = [
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1],
]
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Block causal mask of shape (batch_size, max_seq_len, max_seq_len).
"""
batch_block_attn_masks = []
batch_size = len(seq_lens)
for sample_idx in range(batch_size):
block_attn_masks = [
torch.trilu( # torch.tril(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq_len.device)
)
for seq_len in seq_lens[sample_idx]
]
"""residue_len = max_seq_len - torch.sum(seq_lens[sample_idx])
block_attn_masks.append(
torch.tril(
torch.ones(
residue_len, residue_len, dtype=torch.bool, device=seq_lens[sample_idx].device
)
)
)"""
batch_block_attn_masks.append(torch.block_diag(*block_attn_masks))
return torch.stack(batch_block_attn_masks)[:, None, :, :]
def _get_document_ids_from_seq_lens(
seq_lens: list[torch.Tensor],
) -> torch.Tensor:
"""
Convert a batch tensor of seq lens into integer IDs denoting sample ownership.
For example, seq_lens = [2, 3, 1] would return [0, 0, 1, 1, 1, 2].
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
Tensor: Document IDs of shape (batch_size, max_seq_len).
"""
batch_size = len(seq_lens)
batch_document_ids = []
for sample_idx in range(batch_size):
# We assume seq lens sum to max seq lens, so document_ids should be of
# shape (max_seq_len, )
document_ids = torch.cat(
[
torch.full((seq_len,), i, dtype=torch.long, device=seq_len.device)
for i, seq_len in enumerate(seq_lens[sample_idx])
]
)
batch_document_ids.append(document_ids)
batch_document_ids = torch.stack(batch_document_ids)
return batch_document_ids
def packed_block_causal_mask(
seq_lens: list[torch.Tensor], totalseqlens: list[int]
) -> _MaskType:
"""
Create a block causal document mask for a batch of packed sequences. If
flex attention is supported by the current hardware, block causal logic and
passing this into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. If on an older version, a standard 2D block causal mask is created and returned.
Args:
seq_lens (List[torch.Tensor]): Sequence lengths of samples in each pack in the batch,
shape (batch_size, n), where n is the max number of sequences in a pack and can vary
across packs.
Returns:
_MaskType: BlockMask or Tensor if torch version < 2.5.0.
"""
document_ids = _get_document_ids_from_seq_lens(seq_lens)
batch_size , max_seq_len = document_ids.shape
document_ids = document_ids.to("cuda")
totalseqlens = totalseqlens.to("cuda")
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def mask_mod(b, h, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
document_mask = document_ids[b, q_idx] == document_ids[b, kv_idx]
finite_mask = q_idx < totalseqlens[b]
return causal_mask & document_mask & finite_mask
return create_block_causal_mask_flex(
mask_mod,
batch_size,
None,
max_seq_len,
max_seq_len,
device="cuda",
BLOCK_SIZE=512,
)

View File

@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_packed_mask_from_pos_ids(position_ids):
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)
return torch.stack(results)
def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
totalseqlens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)

View File

@@ -4,7 +4,6 @@ shared axolotl collators for multipack, mamba, multimodal
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
FlexBatchSamplerDataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)

View File

@@ -3,21 +3,12 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.flex_attn import (
create_block_causal_mask,
packed_block_causal_mask,
)
from axolotl.monkeypatch.utils import (
get_packed_mask_from_pos_ids,
)
@dataclass
class DataCollatorForSeq2Seq:
@@ -160,42 +151,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to Flex Attention using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
elif feature == "attention_mask":
"""arrays = [
i * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)"""
continue
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
out = super().__call__(out_features, return_tensors=return_tensors)
# collated_seq_lens, totalseqlens = get_seqlens_from_pos_ids(out["position_ids"])
# out["attention_mask"] = packed_block_causal_mask(collated_seq_lens, totalseqlens)
out["attention_mask"] = get_packed_mask_from_pos_ids(out["position_ids"])
# out["attention_mask"] = create_block_causal_mask(collated_seq_lens, max_seq_len)
return out
@dataclass
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""

View File

@@ -784,7 +784,6 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
@@ -1680,26 +1679,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):

View File

@@ -403,7 +403,7 @@ class ModelLoader:
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.flash_attention
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
@@ -708,13 +708,7 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1106,7 +1100,7 @@ class ModelLoader:
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
)