From ee489d16bf3de1e324a8988c21b77daca5c7a306 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 27 Feb 2025 11:42:46 -0500
Subject: [PATCH] wip

---
 src/axolotl/core/trainer_builder.py          |  6 +++
 src/axolotl/core/training_args.py            | 10 +++++
 .../attention/sequence_parallel/__init__.py  | 37 ++++++++++++++++---
 .../attention/sequence_parallel/usp.py       | 10 ++---
 src/axolotl/utils/collators/batching.py      |  7 +++-
 .../config/models/input/v0_4_1/__init__.py   |  2 +
 src/axolotl/utils/distributed.py             |  6 +++
 src/axolotl/utils/trainer.py                 |  2 +-
 8 files changed, 68 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 12346b8a2..290db063c 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -59,6 +59,7 @@ from axolotl.core.training_args import (
     AxolotlTrainingArguments,
 )
 from axolotl.integrations.base import PluginManager
+from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -746,6 +747,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = 64
+        if self.cfg.sp_ulysses_degree:
+            data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
+                USPRingAttnType.ZIGZAG,
+                sp_ulysses_degree=self.cfg.sp_ulysses_degree,
+            )
 
         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 7cace7643..f19293c6e 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -206,6 +206,16 @@ class AxolotlTrainingMixins:
         },
     )
+    sp_ulysses_degree: Optional[int] = field(
+        default=None,
+        metadata={"help": "Ulysses parallelism degree for hybrid sequence parallel long-context attention"},
+    )
+
+    sp_ring_degree: Optional[int] = field(
+        default=None,
+        metadata={"help": "Ring attention parallelism degree for sequence parallel long-context attention"},
+    )
+
 
 
 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
diff --git a/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py b/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
index 3ff99c775..23010828a 100644
--- a/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
+++ b/src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py
@@ -1,18 +1,45 @@
 from enum import Enum
+from functools import partial
 
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
+
+from axolotl.utils.distributed import get_world_size, get_rank
+
 
 class USPRingAttnType(Enum):
     BASIC = "basic"
     ZIGZAG = "zigzag"
-    STRIDE = "stride"
-
-def patch_seq_parallel():
-    ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
-    pass
+    STRIPE = "stripe"
 
 def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
     from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
 
     fa_forward = build_usp_fa_forward(ring_impl_type)
     ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
+
+def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
+    fn = EXTRACT_FUNC_DICT["basic"]
+    if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
+        fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
+
+    # upstream keys the stripe extraction fn under "strip", so map it explicitly
+    elif ring_impl_type == USPRingAttnType.STRIPE:
+        fn = EXTRACT_FUNC_DICT["strip"]
+
+    world_size = get_world_size()
+    rd = world_size // sp_ulysses_degree
+
+    return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
+
+def set_usp_parallel_group(sp_ulysses_degree):
+    """
+    Set up the distributed process groups for USP (hybrid Ulysses + ring) attention.
+    Make sure this is called before building any USP attention modules.
+    :param sp_ulysses_degree: degree of Ulysses (head) parallelism; ring degree is world_size // sp_ulysses_degree
+    :return: None
+    """
+    world_size = get_world_size()
+    rank = get_rank()
+    sp_ring_degree = world_size // sp_ulysses_degree
+    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
diff --git a/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py b/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
index 34ee792c5..8ed2cb89a 100644
--- a/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
+++ b/src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
@@ -2,7 +2,7 @@
 from enum import Enum
 from typing import Optional, Tuple, Callable
 
 import torch
-from yunchang import LongContextAttention, set_seq_parallel_pg
+from yunchang import LongContextAttention
 
 from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
@@ -11,16 +11,16 @@ def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
     usp_attn = LongContextAttention(ring_impl_type.value)
 
     def flash_attention_forward(
-        module: torch.nn.Module,
+        module: torch.nn.Module,  # pylint: disable=unused-argument
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attention_mask: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor],  # pylint: disable=unused-argument
         dropout: float = 0.0,
         scaling: Optional[float] = None,
-        sliding_window: Optional[int] = None,
+        sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
         softcap: Optional[float] = None,
-        **kwargs,
+        **kwargs,  # pylint: disable=unused-argument
     ) -> Tuple[torch.Tensor, None]:
         attn_output = usp_attn(
             query,
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 7cf771421..bcee98af6 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import numpy as np
 from transformers import PreTrainedTokenizerBase
@@ -53,6 +53,7 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
+    sp_extract_fn: Optional[Callable] = None
 
     def __call__(self, features, return_tensors=None):
         labels = None
@@ -121,6 +122,10 @@ class DataCollatorForSeq2Seq:
 
         return features
 
+    def seq_parallel_split(self, features):
+        if self.sp_extract_fn:
+            pass  # TODO: shard each feature tensor for this rank via self.sp_extract_fn
+        return features
 
 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 1810413be..184a18e95 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -832,6 +832,8 @@ class AxolotlInputConfig(
 
     eager_attention: Optional[bool] = None
 
+    sp_ulysses_degree: Optional[int] = None
+
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 81a928b6e..ff05d5065 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -86,6 +86,12 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+def get_rank():
+    if not is_distributed():
+        return 0
+    return dist.get_rank()
+
+
 @contextmanager
 def zero_only():
     """
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 8553339b9..2cc0c50f4 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing:
+    elif cfg.sample_packing or cfg.sp_ulysses_degree:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
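
Note: the snippet below is a minimal usage sketch and is not part of the patch. It only shows how the helpers introduced above are intended to compose; it assumes torch.distributed is already initialized, that the world size is divisible by sp_ulysses_degree, and that tokenizer is an already-loaded PreTrainedTokenizerBase.

    from axolotl.monkeypatch.attention.sequence_parallel import (
        USPRingAttnType,
        apply_usp_attn_patch,
        get_extract_fn,
        set_usp_parallel_group,
    )
    from axolotl.utils.collators.batching import DataCollatorForSeq2Seq

    sp_ulysses_degree = 2  # hypothetical value; ring degree becomes world_size // 2

    # build the Ulysses + ring process groups before any USP attention module is constructed
    set_usp_parallel_group(sp_ulysses_degree)

    # route HF "flash_attention_2" calls through yunchang's LongContextAttention
    apply_usp_attn_patch(USPRingAttnType.ZIGZAG)

    # hand the collator the function that extracts this rank's shard of each packed sequence
    collator = DataCollatorForSeq2Seq(
        tokenizer,  # assumed: a loaded PreTrainedTokenizerBase
        sp_extract_fn=get_extract_fn(
            USPRingAttnType.ZIGZAG, sp_ulysses_degree=sp_ulysses_degree
        ),
    )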