diff --git a/requirements.txt b/requirements.txt
index 02ecd56f7..7aad4ff9b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -67,4 +67,5 @@ axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3
 
 # for sequence parallelism
+yunchang>=0.6.0
 ring-flash-attn>=0.1.4
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 83bfb1c83..97de5a92d 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -758,9 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.kd_zscore_base_temp
             )
         if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                self.cfg.kd_top_k_before_softmax
-            )
+            training_arguments_kwargs[
+                "kd_top_k_before_softmax"
+            ] = self.cfg.kd_top_k_before_softmax
+
+        training_arguments_kwargs[
+            "sequence_parallel_size"
+        ] = self.cfg.sequence_parallel_size
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -793,7 +797,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.reward_model:
            data_collator_kwargs["max_length"] = self.cfg.sequence_len
-
+
         trainer_cls = self._get_trainer_cls()
         trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
             trainer_kwargs, trainer_cls
@@ -845,9 +849,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
-            if self.cfg.pretraining_sample_concatenation is False:
-                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-            if self.cfg.micro_batch_size > 1:
+            if (
+                self.cfg.pretraining_sample_concatenation is False
+                or self.cfg.micro_batch_size > 1
+            ):
                 return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 6570db967..35cf2f30a 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -9,11 +9,14 @@ import logging
 import os
 from collections import defaultdict
 from functools import wraps
-from typing import Dict, Literal, Optional
+from typing import Any, Dict, Literal, Optional
+from typing_extensions import override
 
 import torch
+import torch.nn.functional as F
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
+from ring_flash_attn import update_ring_flash_attn_params
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -25,6 +28,7 @@ from trl.trainer.utils import pad_to_length
 
 from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
+from axolotl.utils.ring_attn import get_ring_attn_group
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
     RexLR,
@@ -796,6 +800,58 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
+
+    @override
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch=None,
+    ) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs.
+
+        Note: we override `transformers.trainer.Trainer.training_step` in order to
+        compute parameters needed by the ring flash attention implementation we're
+        using.
+
+        Args:
+            model (`nn.Module`):
+                The model to train.
+            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model. The dictionary will be unpacked
+                before being fed to the model. Most models expect the targets under
+                the argument `labels`. Check your model's documentation for all
+                accepted arguments.
+
+        Returns:
+            `torch.Tensor`: The tensor with training loss on this batch.
+        """
+        if self.args.sequence_parallel_size > 1:
+            if "attention_mask" in inputs:
+                # Calculate per-sequence lengths from the attention mask
+                seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
+                total_seq_len = (
+                    inputs["attention_mask"].shape[0]
+                    * inputs["attention_mask"].shape[1]
+                )
+            else:
+                # Assume all sequences are the same length if no mask is provided
+                batch_size = inputs["input_ids"].shape[0]
+                seq_len = inputs["input_ids"].shape[1]
+                seq_lens = [seq_len] * batch_size
+                total_seq_len = batch_size * seq_len
+
+            self._update_ring_flash_attn_params(seq_lens, total_seq_len)
+
+        return super().training_step(model, inputs, num_items_in_batch)
+
+    def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
+        """
+        Calculate the cu_seqlens for the current forward pass and pass the value to
+        the substituted ring_flash_attn.
+        """
+        cu_seqlens = torch.cumsum(
+            torch.tensor(
+                packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
+            ),
+            dim=-1,
+            dtype=torch.int32,
+        )
+        cu_seqlens = F.pad(
+            F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
+        )
+
+        update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
 
 
 class AxolotlMambaTrainer(AxolotlTrainer):
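The `cu_seqlens` construction in `_update_ring_flash_attn_params` is easy to check in isolation: it turns per-sequence lengths into cumulative offsets, then pads a leading `0` and a trailing `total_seq_len` so the boundaries span the whole (possibly padded) token buffer. A minimal standalone sketch, using hypothetical lengths:

    import torch
    import torch.nn.functional as F

    packed_seq_lens = [3, 5, 4]  # hypothetical per-sequence lengths in a packed batch
    total_seq_len = 16           # batch_size * seq_len, including trailing padding

    cu_seqlens = torch.cumsum(
        torch.tensor(packed_seq_lens, dtype=torch.int32), dim=-1, dtype=torch.int32
    )
    # Prepend 0 and append the padded total so the offsets cover the whole buffer
    cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)

    print(cu_seqlens)  # tensor([ 0,  3,  8, 12, 16], dtype=torch.int32)

Note that the final entry is `total_seq_len` (16), not the sum of the real lengths (12), so any trailing padding is effectively treated as one extra span.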
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 34a79e646..2d3b0dda1 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -206,6 +206,13 @@ class AxolotlTrainingMixins:
             "help": "Whether to apply top_k_before_softmax to the logits when using KD"
         },
     )
+
+    sequence_parallel_size: Optional[int] = field(
+        default=1,
+        metadata={
+            "help": "Number of workers to shard each sequence across for sequence parallelism"
+        },
+    )
 
 
 @dataclass
@@ -213,8 +220,8 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
     Training arguments for Causal trainer
 
-    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
-    so it can't be used as a mixin.
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a
+    default value, so it can't be used as a mixin.
     """
diff --git a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
index 35e680774..38d9f7f00 100644
--- a/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
+++ b/src/axolotl/integrations/easy_context/dist_flash_attn/monkey_patch.py
@@ -4,9 +4,7 @@ Materialization-aware gradient checkpointing monkey patch.
 from typing import List, Optional, Tuple
 
 import torch
-import transformers
 from einops import rearrange
-from torch import nn
 from torch.utils.checkpoint import (
     _get_autocast_kwargs,
     check_backward_validity,
@@ -16,10 +14,12 @@ from torch.utils.checkpoint import (
 )
 from transformers.models.llama.modeling_llama import (
     BaseModelOutputWithPast,
+    LlamaDecoderLayer,
+    LlamaModel,
     apply_rotary_pos_emb,
 )
 
-from .async_communication import initialize_distributed, reset_global_memory_buffer
+from .async_communication import initialize_distributed
 from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
 
 # define a global buffer to save flash attention outputs
@@ -749,7 +749,6 @@ def forward(
 
 def apply_dist_flash_attn_monkey_patch_llama():
     initialize_distributed()
-    transformers.models.llama.modeling_llama.LlamaModel.forward = forward
-    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-        llama_layer_forward
-    )
+
+    LlamaModel.forward = forward
+    LlamaDecoderLayer.forward = llama_layer_forward
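For readers unfamiliar with this style of patch: assigning to a class attribute rebinds the method for every current and future instance, which is all `apply_dist_flash_attn_monkey_patch_llama` does once the imports are tidied. A toy illustration of the pattern (toy classes, not transformers internals):

    class Layer:
        def forward(self, x):
            return x + 1

    def patched_forward(self, x):
        return x * 2  # replacement behavior

    layer = Layer()
    print(layer.forward(3))          # 4
    Layer.forward = patched_forward  # same pattern as LlamaModel.forward = forward
    print(layer.forward(3))          # 6 -- existing instances pick up the patch too

The flip side is that the patch is process-global and irreversible unless the original method is stashed first.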
diff --git a/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py b/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py
index 67dd05631..ddf7dd292 100644
--- a/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py
+++ b/src/axolotl/integrations/easy_context/zigzag_ring_attn/monkey_patch.py
@@ -1,10 +1,14 @@
-import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
-import transformers
 from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+from transformers.models.mistral.modeling_mistral import (
+    MistralAttention,
+    MistralDecoderLayer,
+)
 
 
 def new_flash_attn_forward(
@@ -18,6 +22,10 @@ def new_flash_attn_forward(
     softmax_scale=None,
     use_sliding_windows=False,
 ):
+    assert (
+        self.config._attn_implementation == "flash_attention_2"
+    ), "Only Flash Attention 2 is supported."
+
     if not self._flash_attn_uses_top_left_mask:
         causal = self.is_causal
     else:
@@ -48,26 +56,19 @@ def new_decoder_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
     cache_position: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     **kwargs,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-    assert isinstance(
-        self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
-    ) or isinstance(
+    assert isinstance(self.self_attn, LlamaAttention) or isinstance(
         self.self_attn,
-        transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
-    ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
-
-    if "padding_mask" in kwargs:
-        warnings.warn(
-            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-        )
+        MistralAttention,
+    ), "Only Llama and Mistral attention are supported."
 
     residual = hidden_states
-
     hidden_states = self.input_layernorm(hidden_states)
 
     # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+    hidden_states, self_attn_weights = self.self_attn(
         hidden_states=hidden_states,
         attention_mask=attention_mask,
         position_ids=position_ids,
@@ -75,6 +76,7 @@
         output_attentions=output_attentions,
         use_cache=use_cache,
         cache_position=cache_position,
+        position_embeddings=position_embeddings,
         **kwargs,
     )
     hidden_states = residual + hidden_states
@@ -86,29 +88,19 @@ def new_decoder_forward(
     hidden_states = residual + hidden_states
 
     outputs = (hidden_states,)
-
     if output_attentions:
         outputs += (self_attn_weights,)
 
-    if use_cache:
-        outputs += (present_key_value,)
-
     return outputs
 
 
 def apply_zigzag_ring_attn_monkey_patch_llama():
-    transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
-        new_flash_attn_forward
-    )
-    transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
-        new_decoder_forward
-    )
+    ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
+    LlamaDecoderLayer.forward = new_decoder_forward
 
 
 def apply_zigzag_ring_attn_monkey_patch_mistral():
-    transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
-        new_flash_attn_forward
-    )
-    transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
-        new_decoder_forward
-    )
+    ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
+    MistralDecoderLayer.forward = new_decoder_forward
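Overwriting the `"flash_attention_2"` entry in `ALL_ATTENTION_FUNCTIONS` swaps the kernel for every model in the process. In recent transformers releases a custom kernel can instead be registered under its own name and opted into per model; a sketch under that assumption (the name `zigzag_ring_attention` and the stub body are hypothetical, and the attention-interface callable signature may drift between transformers versions):

    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    def zigzag_ring_attention_forward(module, query, key, value, attention_mask, **kwargs):
        # would dispatch to zigzag_ring_flash_attn_func here
        raise NotImplementedError

    # register alongside, rather than over, the stock flash_attention_2 kernel
    ALL_ATTENTION_FUNCTIONS["zigzag_ring_attention"] = zigzag_ring_attention_forward
    # models loaded with attn_implementation="zigzag_ring_attention" then pick it up

That keeps the stock flash-attention path usable for models that should not participate in the ring.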
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0682b81b5..173a1b99f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,6 +11,7 @@ from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
+from axolotl.utils.ring_attn import register_ring_attn
 import bitsandbytes as bnb
 import torch
 import transformers
@@ -548,18 +549,17 @@ class ModelLoader:
             patch_self_attn_lora(self.cfg)
 
         if self.cfg.sequence_parallel_size > 1:
-            from axolotl.integrations.easy_context import (
-                apply_seq_parallel_monkey_patch,
-            )
-
-            method = self.cfg.sequence_parallel_mode
-            model_type = self.cfg.model_type
-
-            # Apply the monkey patch
-            apply_seq_parallel_monkey_patch(method, model_type)
-
-            # Ensure flash attention is enabled when loading the model
-            self.cfg.attn_implementation = "flash_attention_2"
+            register_ring_attn(self.cfg.sequence_parallel_size)
 
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
diff --git a/src/axolotl/utils/ring_attn.py b/src/axolotl/utils/ring_attn.py
new file mode 100644
index 000000000..09a94b9fa
--- /dev/null
+++ b/src/axolotl/utils/ring_attn.py
@@ -0,0 +1,40 @@
+"""Utilities for registering and retrieving ring attention process groups."""
+
+import torch.distributed as dist
+from ring_flash_attn import substitute_hf_flash_attn
+
+RING_ATTN_GROUP = None
+
+
+def get_ring_attn_group():
+    """Get the ring attention process group for this rank, if one was registered."""
+    return RING_ATTN_GROUP
+
+
+def set_ring_attn_group(ring_attn_group):
+    """Set the ring attention process group for this rank."""
+    global RING_ATTN_GROUP
+    RING_ATTN_GROUP = ring_attn_group
+
+
+def register_ring_attn(sequence_parallel_size):
+    """
+    Create ring attention process groups and substitute HF flash attention with
+    ring flash attention.
+    """
+    if sequence_parallel_size == 1:
+        return
+
+    world_size = dist.get_world_size()
+    assert world_size % sequence_parallel_size == 0, (
+        f"sequence_parallel_size ({sequence_parallel_size}) "
+        f"must evenly divide world_size ({world_size})"
+    )
+
+    # Every rank must take part in each new_group call, even for groups it
+    # does not belong to
+    for i in range(world_size // sequence_parallel_size):
+        ring_attn_ranks = list(
+            range(
+                i * sequence_parallel_size,
+                (i + 1) * sequence_parallel_size,
+            )
+        )
+        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
+        if dist.get_rank() in ring_attn_ranks:
+            set_ring_attn_group(group)
+
+    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
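The grouping loop in `register_ring_attn` assigns contiguous blocks of ranks to each ring. A standalone dry run of the same arithmetic (no process group needed), assuming eight ranks and a sequence-parallel degree of four:

    world_size = 8
    sequence_parallel_size = 4

    groups = [
        list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
        for i in range(world_size // sequence_parallel_size)
    ]
    print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    # Rank 5 joins the second ring, so it exchanges KV blocks only with ranks 4, 6, 7.

Note that `dist.new_group` must be called by all ranks for every group, including groups a rank does not belong to, which is why the loop never breaks early.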
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 090e677a6..d9fbbbc40 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -8,6 +8,7 @@ from contextlib import contextmanager
 from functools import partial
 from typing import List, Optional
 
+from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
 import numpy as np
 import torch
 import torch.cuda
@@ -346,7 +347,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing:
+    elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +357,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             **filter_map_kwargs,
             **drop_long_kwargs,
         )
-        if cfg.eval_sample_packing is not False:
+        if cfg.sequence_parallel_size > 1:
+            train_dataset = train_dataset.map(
+                lambda batch: prepare_seq_parallel_inputs(
+                    "dist_flash_attn",
+                    batch["input_ids"],
+                    batch["position_ids"],
+                    batch["target_ids"],
+                    accelerator.process_index,
+                    accelerator.num_processes,
+                    accelerator.device,
+                ),
+                batched=True,
+            )
+        if cfg.eval_sample_packing is not False or cfg.sequence_parallel_size > 1:
         if eval_dataset:
             eval_dataset = eval_dataset.map(
                 add_position_ids,
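`prepare_seq_parallel_inputs` shards each batch across the sequence-parallel ranks. For the zigzag ring schedule, the usual layout splits the sequence into `2 * world_size` chunks and gives rank `r` chunk `r` plus its mirror, so that cheap early chunks and expensive late chunks of causal attention are balanced across ranks. A conceptual sketch of that layout, not the library's exact helper:

    import torch

    def zigzag_shard(seq: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Split into 2 * world_size chunks; rank r keeps chunk r and its mirror
        chunks = seq.chunk(2 * world_size, dim=0)
        return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=0)

    seq = torch.arange(16)  # toy sequence of length 16
    print(zigzag_shard(seq, rank=0, world_size=4))  # tensor([ 0,  1, 14, 15])

With four ranks, rank 0 holds tokens {0, 1, 14, 15} and rank 3 holds {6, 7, 8, 9}, so every rank attends over a comparable amount of causal context.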