progress on ring attn impl

This commit is contained in:
Dan Saunders
2025-03-04 21:31:11 +00:00
parent 3f8a43cab6
commit bd952de9d2
10 changed files with 174 additions and 60 deletions

5
=0.1.4 Normal file
View File

@@ -0,0 +1,5 @@
Collecting ring-flash-attn
Downloading ring_flash_attn-0.1.4-py3-none-any.whl.metadata (7.3 kB)
Downloading ring_flash_attn-0.1.4-py3-none-any.whl (24 kB)
Installing collected packages: ring-flash-attn
Successfully installed ring-flash-attn-0.1.4

View File

@@ -67,4 +67,5 @@ axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3 axolotl-contribs-mit==0.0.3
# for sequence parallelism # for sequence parallelism
yunchang.=0.6.0
ring-flash-attn>=0.1.4 ring-flash-attn>=0.1.4

View File

@@ -758,9 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_zscore_base_temp self.cfg.kd_zscore_base_temp
) )
if self.cfg.kd_top_k_before_softmax is not None: if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = ( training_arguments_kwargs[
self.cfg.kd_top_k_before_softmax "kd_top_k_before_softmax"
) ] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs[
"sequence_parallel_size"
] = self.cfg.sequence_parallel_size
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -845,9 +849,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False: if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None

View File

@@ -9,11 +9,14 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
from typing import Dict, Literal, Optional from typing import Any, Dict, Literal, Optional
from typing_extensions import override
import torch import torch
import torch.nn.functional as F
from datasets import Dataset from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from ring_flash_attn import update_ring_flash_attn_params
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -25,6 +28,7 @@ from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.ring_attn import get_ring_attn_group
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
RexLR, RexLR,
@@ -797,6 +801,58 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs) return super()._save_checkpoint(model, trial, **kwargs)
@override
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
parameters needed for the ring flash attention implementation we're using.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
if self.args.sequence_parallel_size > 1:
if "attention_mask" in inputs:
# Calculate sequence lengths from attention mask
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
else:
# Assume all sequences are the same length if no mask is provided
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
seq_lens = [seq_len] * batch_size
total_seq_len = batch_size * seq_len
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
return super().training_step(model, inputs, num_items_in_batch)
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
"""
cu_seqlens = torch.cumsum(
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
class AxolotlMambaTrainer(AxolotlTrainer): class AxolotlMambaTrainer(AxolotlTrainer):
""" """

View File

@@ -207,14 +207,21 @@ class AxolotlTrainingMixins:
}, },
) )
sequence_parallel_size: Optional[int] = field(
default=1,
metadata={
"help": "The number of workers to use in sequence parallelism"
},
)
@dataclass @dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value This code is duplicated due to HF TrainingArguments not setting output_dir with a
so it can't be used as a mixin. default value so it can't be used as a mixin.
""" """

View File

@@ -4,9 +4,7 @@ Materialization-aware gradient checkpointing monkey patch.
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import transformers
from einops import rearrange from einops import rearrange
from torch import nn
from torch.utils.checkpoint import ( from torch.utils.checkpoint import (
_get_autocast_kwargs, _get_autocast_kwargs,
check_backward_validity, check_backward_validity,
@@ -16,10 +14,12 @@ from torch.utils.checkpoint import (
) )
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
LlamaDecoderLayer,
LlamaModel,
apply_rotary_pos_emb, apply_rotary_pos_emb,
) )
from .async_communication import initialize_distributed, reset_global_memory_buffer from .async_communication import initialize_distributed
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
# define a global buffer to save flash attention outputs # define a global buffer to save flash attention outputs
@@ -749,7 +749,6 @@ def forward(
def apply_dist_flash_attn_monkey_patch_llama(): def apply_dist_flash_attn_monkey_patch_llama():
initialize_distributed() initialize_distributed()
transformers.models.llama.modeling_llama.LlamaModel.forward = forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = ( LlamaModel.forward = forward
llama_layer_forward LlamaDecoderLayer.forward = llama_layer_forward
)

View File

@@ -1,10 +1,14 @@
import warnings from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
import transformers
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
)
def new_flash_attn_forward( def new_flash_attn_forward(
@@ -18,6 +22,10 @@ def new_flash_attn_forward(
softmax_scale=None, softmax_scale=None,
use_sliding_windows=False, use_sliding_windows=False,
): ):
assert (
self.config._attn_implementation == "flash_attention_2"
), "Only Flash Attention is supported."
if not self._flash_attn_uses_top_left_mask: if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal causal = self.is_causal
else: else:
@@ -48,26 +56,19 @@ def new_decoder_forward(
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
assert isinstance( assert isinstance(self.self_attn, LlamaAttention) or isinstance(
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
) or isinstance(
self.self_attn, self.self_attn,
transformers.models.mistral.modeling_mistral.MistralFlashAttention2, MistralAttention,
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." ), "Llama and Mistral attention only are supported."
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
@@ -75,6 +76,7 @@ def new_decoder_forward(
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
@@ -86,29 +88,19 @@ def new_decoder_forward(
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)
if output_attentions: if output_attentions:
outputs += (self_attn_weights,) outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs return outputs
def apply_zigzag_ring_attn_monkey_patch_llama(): def apply_zigzag_ring_attn_monkey_patch_llama():
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = ( # LlamaAttention._flash_attention_forward = new_flash_attn_forward
new_flash_attn_forward ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
) LlamaDecoderLayer.forward = new_decoder_forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
new_decoder_forward
)
def apply_zigzag_ring_attn_monkey_patch_mistral(): def apply_zigzag_ring_attn_monkey_patch_mistral():
transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( # MistralAttention._flash_attention_forward = new_flash_attn_forward
new_flash_attn_forward ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
) MistralDecoderLayer.forward = new_decoder_forward
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
new_decoder_forward
)

View File

@@ -11,6 +11,7 @@ from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict import addict
from axolotl.utils.ring_attn import register_ring_attn
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
@@ -548,18 +549,17 @@ class ModelLoader:
patch_self_attn_lora(self.cfg) patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_size > 1: if self.cfg.sequence_parallel_size > 1:
from axolotl.integrations.easy_context import ( # from axolotl.integrations.easy_context import (
apply_seq_parallel_monkey_patch, # apply_seq_parallel_monkey_patch,
) # )
method = self.cfg.sequence_parallel_mode # method = self.cfg.sequence_parallel_mode
model_type = self.cfg.model_type # model_type = self.cfg.model_config_type
# Apply the monkey patch # # Apply the monkey patch
apply_seq_parallel_monkey_patch(method, model_type) # apply_seq_parallel_monkey_patch(method, model_type)
# Ensure flash attention is enabled when loading the model register_ring_attn(self.cfg.sequence_parallel_size)
self.cfg.attn_implementation = "flash_attention_2"
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):

View File

@@ -0,0 +1,40 @@
import torch.distributed as dist
from ring_flash_attn import substitute_hf_flash_attn
RING_ATTN_GROUP = None
def get_ring_attn_group():
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group):
global RING_ATTN_GROUP
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_size):
"""
Create ring attention group and substitute flash attention with ring flash
attention.
"""
if sequence_parallel_size == 1:
return
world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, \
f"sequence_parallel_size ({sequence_parallel_size}) " \
f"must evenly divide world_size ({world_size})"
for i in range(world_size // sequence_parallel_size):
ring_attn_ranks = list(
range(
i * sequence_parallel_size,
(i + 1) * sequence_parallel_size,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
if dist.get_rank() in ring_attn_ranks:
set_ring_attn_group(group)
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)

View File

@@ -8,6 +8,7 @@ from contextlib import contextmanager
from functools import partial from functools import partial
from typing import List, Optional from typing import List, Optional
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
import numpy as np import numpy as np
import torch import torch
import torch.cuda import torch.cuda
@@ -346,7 +347,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing: elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +357,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
if cfg.eval_sample_packing is not False: if cfg.sequence_parallel_size > 1:
train_dataset.map(
prepare_seq_parallel_inputs,
"dist_flash_attn",
lambda batch: batch["input_ids"],
lambda batch: batch["position_ids"],
lambda batch: batch["target_ids"],
accelerator.process_index,
accelerator.num_processes,
accelerator.device,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,