From f487910444b2181160c92bbd7a7aad15feaffdb0 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 5 Mar 2025 15:18:53 +0000
Subject: [PATCH] removing unused code

---
 src/axolotl/monkeypatch/sequence_parallel.py | 123 -------------------------
 src/axolotl/utils/schemas/config.py          |   2 +
 2 files changed, 2 insertions(+), 123 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/sequence_parallel.py

diff --git a/src/axolotl/monkeypatch/sequence_parallel.py b/src/axolotl/monkeypatch/sequence_parallel.py
deleted file mode 100644
index b3caefb14..000000000
--- a/src/axolotl/monkeypatch/sequence_parallel.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""
-Utilities for sequence parallelism implementation.
-
-Modified from:
-https://github.com/Qihoo360/360-LLaMA-Factory/blob/f295a5760cceebe069fb5b975813d2c945598acb/src/llamafactory/model/model_utils/sequence_parallel.py
-"""
-
-from functools import partial
-
-import torch.distributed as dist
-import transformers
-import transformers.modeling_attn_mask_utils
-from ring_flash_attn import (
-    ring_flash_attn_func,
-    stripe_flash_attn_func,
-    zigzag_ring_flash_attn_func,
-)
-
-
-def ring_flash_attn_forward(
-    query_states,
-    key_states,
-    value_states,
-    attention_mask,
-    q_len,
-    dropout=0,
-    sliding_window=None,
-    is_causal=True,
-    group=None,
-    **kwargs,
-):
-    attn_output = ring_flash_attn_func(
-        query_states, key_states, value_states, dropout, causal=is_causal, group=group
-    )
-
-    return attn_output
-
-
-def zigzag_flash_attn_forward(
-    query_states,
-    key_states,
-    value_states,
-    attention_mask,
-    q_len,
-    dropout=0,
-    sliding_window=None,
-    is_causal=True,
-    group=None,
-    **kwargs,
-):
-    attn_output = zigzag_ring_flash_attn_func(
-        query_states, key_states, value_states, dropout, causal=is_causal, group=group
-    )
-
-    return attn_output
-
-
-def stripe_flash_attn_forward(
-    query_states,
-    key_states,
-    value_states,
-    attention_mask,
-    q_len,
-    dropout=0,
-    sliding_window=None,
-    is_causal=True,
-    group=None,
-    **kwargs,
-):
-    attn_output = stripe_flash_attn_func(
-        query_states, key_states, value_states, dropout, causal=is_causal, group=group
-    )
-
-    return attn_output
-
-
-def init_sp_group(sp_size):
-    assert dist.is_initialized()
-    world_size = dist.get_world_size()
-    assert (
-        world_size % sp_size == 0
-    ), "Total number of GPUs must be a multiple of sequence_parallel_size."
-
-    sp_group_num = world_size // sp_size
-    sp_ranks_list = [
-        list(range(i * sp_size, i * sp_size + sp_size)) for i in range(sp_group_num)
-    ]
-
-    sp_groups = [dist.new_group(sp_ranks_this) for sp_ranks_this in sp_ranks_list]
-
-    global_rank_this = dist.get_rank()
-    sp_idx = global_rank_this // sp_size
-    return sp_groups[sp_idx]
-
-
-def apply_sequence_parallel(cfg):
-    if cfg.sequence_parallel_size == 1:
-        return None  # no sequence parallelism
-
-    # init sequence-parallel groups here
-    group_this = init_sp_group(cfg.sequence_parallel_size)
-
-    if cfg.sequence_parallel_mode == "ring":
-        new_flash_attention_forward = partial(ring_flash_attn_forward, group=group_this)
-    elif cfg.sequence_parallel_mode == "zigzag-ring":
-        new_flash_attention_forward = partial(
-            zigzag_flash_attn_forward, group=group_this
-        )
-    elif cfg.sequence_parallel_mode == "stripe":
-        new_flash_attention_forward = partial(
-            stripe_flash_attn_forward, group=group_this
-        )
-    else:
-        raise NotImplementedError(
-            "Other sequence parallel modes are to be implemented."
-        )
-
-    # monkey patching
-    transformers.modeling_flash_attention_utils._flash_attention_forward = (
-        new_flash_attention_forward
-    )
-
-    return group_this
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 7676a50a8..f57586a82 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -245,6 +245,8 @@ class AxolotlInputConfig(
     val_set_size: float | None = Field(default=0.0)
 
+    sequence_parallel_size: int | None = 1
+
     special_tokens: SpecialTokensConfig | None = None
     tokens: list[str] | None = None
     added_tokens_overrides: dict[int, str] | None = None