progress on ring attn impl
@@ -758,9 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kd_zscore_base_temp
             )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                self.cfg.kd_top_k_before_softmax
-            )
+            training_arguments_kwargs[
+                "kd_top_k_before_softmax"
+            ] = self.cfg.kd_top_k_before_softmax
+
+        training_arguments_kwargs[
+            "sequence_parallel_size"
+        ] = self.cfg.sequence_parallel_size
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -793,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
 
         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
@@ -845,9 +849,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
-            if self.cfg.pretraining_sample_concatenation is False:
-                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-            if self.cfg.micro_batch_size > 1:
+            if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
@@ -9,11 +9,14 @@ import logging
 import os
 from collections import defaultdict
 from functools import wraps
-from typing import Dict, Literal, Optional
+from typing import Any, Dict, Literal, Optional
+from typing_extensions import override
 
 import torch
+import torch.nn.functional as F
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
+from ring_flash_attn import update_ring_flash_attn_params
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -25,6 +28,7 @@ from trl.trainer.utils import pad_to_length
 
 from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
+from axolotl.utils.ring_attn import get_ring_attn_group
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     RexLR,
@@ -796,6 +800,58 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
 
+    @override
+    def training_step(
+        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
+    ) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs.
+
+        Note: we are subclassing `transformers.trainer.Trainer` in order to compute
+        parameters needed for the ring flash attention implementation we're using.
+
+        Args:
+            model (`nn.Module`):
+                The model to train.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+
+        Return:
+            `torch.Tensor`: The tensor with training loss on this batch.
+        """
+        if self.args.sequence_parallel_size > 1:
+            if "attention_mask" in inputs:
+                # Calculate sequence lengths from attention mask
+                seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
+                total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
+            else:
+                # Assume all sequences are the same length if no mask is provided
+                batch_size = inputs["input_ids"].shape[0]
+                seq_len = inputs["input_ids"].shape[1]
+                seq_lens = [seq_len] * batch_size
+                total_seq_len = batch_size * seq_len
+
+            self._update_ring_flash_attn_params(seq_lens, total_seq_len)
+
+        return super().training_step(model, inputs, num_items_in_batch)
+
+    def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
+        """
+        Calculate the cu_seqlens for the current forward pass and pass the value to
+        the substituted ring_flash_attn.
+        """
+        cu_seqlens = torch.cumsum(
+            torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
+            dim=-1,
+            dtype=torch.int32,
+        )
+        cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
+
+        update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
 
 class AxolotlMambaTrainer(AxolotlTrainer):
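For a concrete sense of what `_update_ring_flash_attn_params` computes, here is a small self-contained sketch (not part of the diff; it drops the CUDA device placement and uses a made-up 2x8 batch with real lengths 5 and 3):

import torch
import torch.nn.functional as F

# Made-up batch: 2 sequences padded to length 8, real lengths 5 and 3.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 0, 0, 0, 0, 0]])
seq_lens = attention_mask.sum(dim=1).tolist()                       # [5, 3]
total_seq_len = attention_mask.shape[0] * attention_mask.shape[1]   # 16

# Same two steps as _update_ring_flash_attn_params, minus the CUDA device:
# cumulative lengths, then a leading 0 and a trailing total (padded) length.
cu_seqlens = torch.cumsum(
    torch.tensor(seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
)
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
print(cu_seqlens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)

The trailing entry is the full padded batch length, so everything after the last real token is treated as one final segment when the boundaries are handed to the substituted ring flash attention.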
@@ -206,6 +206,13 @@ class AxolotlTrainingMixins:
             "help": "Whether to apply top_k_before_softmax to the logits when using KD"
         },
     )
 
+    sequence_parallel_size: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": "The number of workers to use in sequence parallelism"
+        },
+    )
+
 
 @dataclass
@@ -213,8 +220,8 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
     Training arguments for Causal trainer
 
-    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
-    so it can't be used as a mixin.
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a
+    default value so it can't be used as a mixin.
     """
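A minimal usage sketch for the new field (illustrative; assumes `AxolotlTrainingArguments` is importable as defined above and that `output_dir` is the only required argument, as with HF `TrainingArguments`):

# Hypothetical sketch: the builder forwards cfg.sequence_parallel_size into
# training_arguments_kwargs, so the value ends up on the trainer as
# self.args.sequence_parallel_size, which training_step checks above.
args = AxolotlTrainingArguments(output_dir="./outputs", sequence_parallel_size=4)
assert args.sequence_parallel_size == 4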
@@ -4,9 +4,7 @@ Materialization-aware gradient checkpointing monkey patch.
 from typing import List, Optional, Tuple
 
 import torch
 import transformers
-from einops import rearrange
-from torch import nn
 from torch.utils.checkpoint import (
     _get_autocast_kwargs,
     check_backward_validity,
@@ -16,10 +14,12 @@ from torch.utils.checkpoint import (
 )
 from transformers.models.llama.modeling_llama import (
     BaseModelOutputWithPast,
+    LlamaDecoderLayer,
+    LlamaModel,
     apply_rotary_pos_emb,
 )
 
-from .async_communication import initialize_distributed, reset_global_memory_buffer
+from .async_communication import initialize_distributed
 from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
 
 # define a global buffer to save flash attention outputs
@@ -749,7 +749,6 @@ def forward(
 
 def apply_dist_flash_attn_monkey_patch_llama():
     initialize_distributed()
-    transformers.models.llama.modeling_llama.LlamaModel.forward = forward
-    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-        llama_layer_forward
-    )
+    LlamaModel.forward = forward
+    LlamaDecoderLayer.forward = llama_layer_forward
@@ -1,10 +1,14 @@
-import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
 import transformers
 from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention,
+    MistralDecoderLayer,
+)
 
 
 def new_flash_attn_forward(
@@ -18,6 +22,10 @@ def new_flash_attn_forward(
     softmax_scale=None,
     use_sliding_windows=False,
 ):
+    assert (
+        self.config._attn_implementation == "flash_attention_2"
+    ), "Only Flash Attention is supported."
+
     if not self._flash_attn_uses_top_left_mask:
         causal = self.is_causal
     else:
@@ -48,26 +56,19 @@ def new_decoder_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
     cache_position: Optional[torch.LongTensor] = None,
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     **kwargs,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-    assert isinstance(
-        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
-    ) or isinstance(
+    assert isinstance(self.self_attn, LlamaAttention) or isinstance(
         self.self_attn,
-        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
-    ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
-
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-        )
+        MistralAttention,
+    ), "Llama and Mistral attention only are supported."
 
     residual = hidden_states
 
     hidden_states = self.input_layernorm(hidden_states)
 
     # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+    hidden_states, self_attn_weights = self.self_attn(
         hidden_states=hidden_states,
         attention_mask=attention_mask,
         position_ids=position_ids,
@@ -75,6 +76,7 @@ def new_decoder_forward(
         output_attentions=output_attentions,
         use_cache=use_cache,
         cache_position=cache_position,
         position_embeddings=position_embeddings,
+        **kwargs,
     )
     hidden_states = residual + hidden_states
@@ -86,29 +88,19 @@ def new_decoder_forward(
     hidden_states = residual + hidden_states
 
     outputs = (hidden_states,)
 
     if output_attentions:
         outputs += (self_attn_weights,)
 
-    if use_cache:
-        outputs += (present_key_value,)
-
     return outputs
 
 
 def apply_zigzag_ring_attn_monkey_patch_llama():
-    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
-        new_flash_attn_forward
-    )
-    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-        new_decoder_forward
-    )
+    # LlamaAttention._flash_attention_forward = new_flash_attn_forward
+    ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
+    LlamaDecoderLayer.forward = new_decoder_forward
 
 
 def apply_zigzag_ring_attn_monkey_patch_mistral():
-    transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
-        new_flash_attn_forward
-    )
-    transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
-        new_decoder_forward
-    )
+    # MistralAttention._flash_attention_forward = new_flash_attn_forward
+    ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
+    MistralDecoderLayer.forward = new_decoder_forward
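The rewritten patch functions above route through transformers' attention-function registry instead of overriding the removed per-model `*FlashAttention2` classes. A small sketch of why updating that registry is enough (illustrative only, assuming a transformers version that exposes `ALL_ATTENTION_FUNCTIONS` with a "flash_attention_2" entry; the counting wrapper below is made up and not part of the patch):

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Whatever is registered under "flash_attention_2" is what every model loaded
# with attn_implementation="flash_attention_2" ends up calling.
_original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
calls = {"n": 0}

def counting_flash_attn_forward(*args, **kwargs):
    # Keep the registered callable's signature by passing everything through.
    calls["n"] += 1
    return _original_fa2(*args, **kwargs)

# Swapping the registry entry reroutes attention globally, which is the same
# mechanism apply_zigzag_ring_attn_monkey_patch_llama() relies on above.
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": counting_flash_attn_forward})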
@@ -11,6 +11,7 @@ from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
 
 import addict
+from axolotl.utils.ring_attn import register_ring_attn
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -548,18 +549,17 @@ class ModelLoader:
             patch_self_attn_lora(self.cfg)
 
         if self.cfg.sequence_parallel_size > 1:
-            from axolotl.integrations.easy_context import (
-                apply_seq_parallel_monkey_patch,
-            )
+            # from axolotl.integrations.easy_context import (
+            #     apply_seq_parallel_monkey_patch,
+            # )
 
-            method = self.cfg.sequence_parallel_mode
-            model_type = self.cfg.model_type
+            # method = self.cfg.sequence_parallel_mode
+            # model_type = self.cfg.model_config_type
 
-            # Apply the monkey patch
-            apply_seq_parallel_monkey_patch(method, model_type)
-
-            # Ensure flash attention is enabled when loading the model
-            self.cfg.attn_implementation = "flash_attention_2"
+            # # Apply the monkey patch
+            # apply_seq_parallel_monkey_patch(method, model_type)
+
+            register_ring_attn(self.cfg.sequence_parallel_size)
 
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
src/axolotl/utils/ring_attn.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import torch.distributed as dist
+from ring_flash_attn import substitute_hf_flash_attn
+
+RING_ATTN_GROUP = None
+
+
+def get_ring_attn_group():
+    return RING_ATTN_GROUP
+
+
+def set_ring_attn_group(ring_attn_group):
+    global RING_ATTN_GROUP
+    RING_ATTN_GROUP = ring_attn_group
+
+
+def register_ring_attn(sequence_parallel_size):
+    """
+    Create ring attention group and substitute flash attention with ring flash
+    attention.
+    """
+    if sequence_parallel_size == 1:
+        return
+
+    world_size = dist.get_world_size()
+    assert world_size % sequence_parallel_size == 0, \
+        f"sequence_parallel_size ({sequence_parallel_size}) " \
+        f"must evenly divide world_size ({world_size})"
+
+    for i in range(world_size // sequence_parallel_size):
+        ring_attn_ranks = list(
+            range(
+                i * sequence_parallel_size,
+                (i + 1) * sequence_parallel_size,
+            )
+        )
+        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
+        if dist.get_rank() in ring_attn_ranks:
+            set_ring_attn_group(group)
+
+    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
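To illustrate the grouping in `register_ring_attn` (a standalone sketch with made-up sizes, not part of the new file): consecutive ranks are partitioned into rings of `sequence_parallel_size` GPUs, and each ring shares one sequence via ring flash attention.

def ring_groups(world_size: int, sequence_parallel_size: int) -> list[list[int]]:
    # Mirrors the loop in register_ring_attn: consecutive ranks form each ring.
    assert world_size % sequence_parallel_size == 0
    return [
        list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
        for i in range(world_size // sequence_parallel_size)
    ]

print(ring_groups(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(ring_groups(8, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Note that every rank has to participate in each `dist.new_group` call (it is collective), but each rank only stores the group it actually belongs to, which is what the `if dist.get_rank() in ring_attn_ranks` check above does.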
@@ -8,6 +8,7 @@ from contextlib import contextmanager
 from functools import partial
 from typing import List, Optional
 
+from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 import numpy as np
 import torch
 import torch.cuda
@@ -346,7 +347,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-        elif cfg.sample_packing:
+        elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
             drop_long_kwargs = {}
             if filter_map_kwargs:
                 drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +357,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             **filter_map_kwargs,
             **drop_long_kwargs,
         )
-        if cfg.eval_sample_packing is not False:
+        if cfg.sequence_parallel_size > 1:
+            train_dataset.map(
+                prepare_seq_parallel_inputs,
+                "dist_flash_attn",
+                lambda batch: batch["input_ids"],
+                lambda batch: batch["position_ids"],
+                lambda batch: batch["target_ids"],
+                accelerator.process_index,
+                accelerator.num_processes,
+                accelerator.device,
+            )
+        if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
            if eval_dataset:
                eval_dataset = eval_dataset.map(
                    add_position_ids,
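For context on why the sample-packing branch is now also taken when `sequence_parallel_size > 1`: that path attaches an explicit `position_ids` column to each sample, which the sequence-parallel code later relies on instead of inferring positions from padding. A rough sketch of that mapping (a hypothetical simplified stand-in for `add_position_ids`, not the real helper):

def add_position_ids_sketch(sample):
    # Attach explicit positions so they survive packing and sharding.
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample

sample = {"input_ids": [101, 7592, 2088, 102]}
print(add_position_ids_sketch(sample)["position_ids"])  # [0, 1, 2, 3]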