progress on ring attn impl

Dan Saunders
2025-03-04 21:31:11 +00:00
parent 3f8a43cab6
commit bd952de9d2
10 changed files with 174 additions and 60 deletions

View File

@@ -758,9 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs[
"sequence_parallel_size"
] = self.cfg.sequence_parallel_size
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -793,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
@@ -845,9 +849,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
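
For context, a minimal sketch of what the DataCollatorForSeq2Seq fallback does with a ragged micro-batch (illustrative only and not part of this commit; the tokenizer id is a placeholder):

# Illustrative sketch: DataCollatorForSeq2Seq pads every example in a micro-batch
# to the longest sequence in that batch (labels are padded with -100).
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForSeq2Seq(tokenizer)

batch = collator(
    [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5], "labels": [4, 5]},
    ]
)
print(batch["input_ids"].shape)  # torch.Size([2, 3])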

View File

@@ -9,11 +9,14 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Dict, Literal, Optional
from typing import Any, Dict, Literal, Optional
from typing_extensions import override
import torch
import torch.nn.functional as F
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from ring_flash_attn import update_ring_flash_attn_params
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -25,6 +28,7 @@ from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.ring_attn import get_ring_attn_group
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
@@ -796,6 +800,58 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
@override
def training_step(
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
parameters needed for the ring flash attention implementation we're using.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
if self.args.sequence_parallel_size > 1:
if "attention_mask" in inputs:
# Calculate sequence lengths from attention mask
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
else:
# Assume all sequences are the same length if no mask is provided
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
seq_lens = [seq_len] * batch_size
total_seq_len = batch_size * seq_len
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
return super().training_step(model, inputs, num_items_in_batch)
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn.
"""
cu_seqlens = torch.cumsum(
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
class AxolotlMambaTrainer(AxolotlTrainer):
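
To make the cu_seqlens construction in _update_ring_flash_attn_params above concrete, here is a small worked sketch with made-up lengths (CPU tensors here; the real code builds the tensor on the current CUDA device):

import torch
import torch.nn.functional as F

# Illustrative numbers: three packed sequences of lengths 5, 3 and 4 in a buffer
# of total length 16 (the tail is padding).
packed_seq_lens = [5, 3, 4]
total_seq_len = 16

cu_seqlens = torch.cumsum(
    torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
)
# Prepend 0 and append the padded total length, mirroring the nested F.pad calls.
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
print(cu_seqlens)  # tensor([ 0,  5,  8, 12, 16], dtype=torch.int32)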

View File

@@ -206,6 +206,13 @@ class AxolotlTrainingMixins:
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
sequence_parallel_size: Optional[int] = field(
default=1,
metadata={
"help": "The number of workers to use in sequence parallelism"
},
)
@dataclass
@@ -213,8 +220,8 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
This code is duplicated due to HF TrainingArguments not setting output_dir with a
default value so it can't be used as a mixin.
"""

View File

@@ -4,9 +4,7 @@ Materialization-aware gradient checkpointing monkey patch.
from typing import List, Optional, Tuple
import torch
import transformers
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import (
_get_autocast_kwargs,
check_backward_validity,
@@ -16,10 +14,12 @@ from torch.utils.checkpoint import (
)
from transformers.models.llama.modeling_llama import (
BaseModelOutputWithPast,
LlamaDecoderLayer,
LlamaModel,
apply_rotary_pos_emb,
)
from .async_communication import initialize_distributed, reset_global_memory_buffer
from .async_communication import initialize_distributed
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
# define a global buffer to save flash attention outputs
@@ -749,7 +749,6 @@ def forward(
def apply_dist_flash_attn_monkey_patch_llama():
initialize_distributed()
transformers.models.llama.modeling_llama.LlamaModel.forward = forward
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
llama_layer_forward
)
LlamaModel.forward = forward
LlamaDecoderLayer.forward = llama_layer_forward

View File

@@ -1,10 +1,14 @@
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
import transformers
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
)
def new_flash_attn_forward(
@@ -18,6 +22,10 @@ def new_flash_attn_forward(
softmax_scale=None,
use_sliding_windows=False,
):
assert (
self.config._attn_implementation == "flash_attention_2"
), "Only Flash Attention is supported."
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
@@ -48,26 +56,19 @@ def new_decoder_forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
assert isinstance(
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
) or isinstance(
assert isinstance(self.self_attn, LlamaAttention) or isinstance(
self.self_attn,
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
MistralAttention,
), "Llama and Mistral attention only are supported."
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -75,6 +76,7 @@ def new_decoder_forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -86,29 +88,19 @@ def new_decoder_forward(
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def apply_zigzag_ring_attn_monkey_patch_llama():
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
new_decoder_forward
)
# LlamaAttention._flash_attention_forward = new_flash_attn_forward
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
LlamaDecoderLayer.forward = new_decoder_forward
def apply_zigzag_ring_attn_monkey_patch_mistral():
transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
new_flash_attn_forward
)
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
new_decoder_forward
)
# MistralAttention._flash_attention_forward = new_flash_attn_forward
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
MistralDecoderLayer.forward = new_decoder_forward
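
A hedged usage sketch (the model id is a placeholder, not part of the commit): once the patch has installed new_flash_attn_forward under the "flash_attention_2" key of ALL_ATTENTION_FUNCTIONS and swapped in new_decoder_forward, a model loaded with that attention implementation routes its attention calls through the ring-attention path.

# Illustrative only, not part of the commit.
from transformers import AutoModelForCausalLM

apply_zigzag_ring_attn_monkey_patch_llama()
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder model id
    attn_implementation="flash_attention_2",
)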

View File

@@ -11,6 +11,7 @@ from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
import addict
from axolotl.utils.ring_attn import register_ring_attn
import bitsandbytes as bnb
import torch
import transformers
@@ -548,18 +549,17 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_size > 1:
from axolotl.integrations.easy_context import (
apply_seq_parallel_monkey_patch,
)
# from axolotl.integrations.easy_context import (
# apply_seq_parallel_monkey_patch,
# )
method = self.cfg.sequence_parallel_mode
model_type = self.cfg.model_type
# method = self.cfg.sequence_parallel_mode
# model_type = self.cfg.model_config_type
# Apply the monkey patch
apply_seq_parallel_monkey_patch(method, model_type)
# Ensure flash attention is enabled when loading the model
self.cfg.attn_implementation = "flash_attention_2"
# # Apply the monkey patch
# apply_seq_parallel_monkey_patch(method, model_type)
register_ring_attn(self.cfg.sequence_parallel_size)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -0,0 +1,40 @@
import torch.distributed as dist
from ring_flash_attn import substitute_hf_flash_attn
RING_ATTN_GROUP = None
def get_ring_attn_group():
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group):
global RING_ATTN_GROUP
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_size):
"""
Create ring attention group and substitute flash attention with ring flash
attention.
"""
if sequence_parallel_size == 1:
return
world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, \
f"sequence_parallel_size ({sequence_parallel_size}) " \
f"must evenly divide world_size ({world_size})"
for i in range(world_size // sequence_parallel_size):
ring_attn_ranks = list(
range(
i * sequence_parallel_size,
(i + 1) * sequence_parallel_size,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
if dist.get_rank() in ring_attn_ranks:
set_ring_attn_group(group)
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
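
To illustrate the grouping loop above with made-up numbers: with a world size of 8 and sequence_parallel_size of 4, the ranks are partitioned into two ring-attention groups, and each rank keeps only the group it belongs to.

# Illustrative sketch of the rank partitioning performed by register_ring_attn.
world_size = 8
sequence_parallel_size = 4

groups = [
    list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
    for i in range(world_size // sequence_parallel_size)
]
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]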

View File

@@ -8,6 +8,7 @@ from contextlib import contextmanager
from functools import partial
from typing import List, Optional
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
import numpy as np
import torch
import torch.cuda
@@ -346,7 +347,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +357,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing is not False:
if cfg.sequence_parallel_size > 1:
train_dataset.map(
prepare_seq_parallel_inputs,
"dist_flash_attn",
lambda batch: batch["input_ids"],
lambda batch: batch["position_ids"],
lambda batch: batch["target_ids"],
accelerator.process_index,
accelerator.num_processes,
accelerator.device,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,