removing unused code

Dan Saunders
2025-03-05 15:18:53 +00:00
parent c5071dfd8a
commit f487910444
2 changed files with 2 additions and 123 deletions


@@ -1,123 +0,0 @@
"""
Utilities for sequence parallelism implementation.
Modified from:
https://github.com/Qihoo360/360-LLaMA-Factory/blob/f295a5760cceebe069fb5b975813d2c945598acb/src/llamafactory/model/model_utils/sequence_parallel.py
"""
from functools import partial
import torch.distributed as dist
import transformers
import transformers.modeling_attn_mask_utils
from ring_flash_attn import (
ring_flash_attn_func,
stripe_flash_attn_func,
zigzag_ring_flash_attn_func,
)
def ring_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = ring_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def zigzag_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = zigzag_ring_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def stripe_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0,
sliding_window=None,
is_causal=True,
group=None,
**kwargs,
):
attn_output = stripe_flash_attn_func(
query_states, key_states, value_states, dropout, causal=is_causal, group=group
)
return attn_output
def init_sp_group(sp_size):
assert dist.is_initialized()
world_size = dist.get_world_size()
assert (
world_size % sp_size == 0
), "Total number of GPUs must be a multiple of sequence_parallel_size."
sp_group_num = world_size // sp_size
sp_ranks_list = [
list(range(i * sp_size, i * sp_size + sp_size)) for i in range(sp_group_num)
]
sp_groups = [dist.new_group(sp_ranks_this) for sp_ranks_this in sp_ranks_list]
global_rank_this = dist.get_rank()
sp_idx = global_rank_this // sp_size
return sp_groups[sp_idx]
def apply_sequence_parallel(cfg):
if cfg.sequence_parallel_size == 1:
return None # no sequence parallelism
# init sequence-parallel groups here
group_this = init_sp_group(cfg.sequence_parallel_size)
if cfg.sequence_parallel_mode == "ring":
new_flash_attention_forward = partial(ring_flash_attn_forward, group=group_this)
elif cfg.sequence_parallel_mode == "zigzag-ring":
new_flash_attention_forward = partial(
zigzag_flash_attn_forward, group=group_this
)
elif cfg.sequence_parallel_mode == "stripe":
new_flash_attention_forward = partial(
stripe_flash_attn_forward, group=group_this
)
else:
raise NotImplementedError(
"Other sequence parallel modes are to be implemented."
)
# monkey patching
transformers.modeling_flash_attention_utils._flash_attention_forward = (
new_flash_attention_forward
)
return group_this
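For reference, a minimal sketch of how the removed apply_sequence_parallel helper would have been invoked from a training entry point. Only apply_sequence_parallel and the two config fields it reads (sequence_parallel_size, sequence_parallel_mode) come from the file above; setup_training and the process-group initialization are hypothetical wiring, not Axolotl's actual call site.

import torch.distributed as dist

def setup_training(cfg):
    # Hypothetical entry point; cfg is assumed to expose
    # sequence_parallel_size and sequence_parallel_mode.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    # Patches transformers' flash-attention forward as a side effect and
    # returns this rank's sequence-parallel process group (None if size == 1).
    sp_group = apply_sequence_parallel(cfg)

    # Downstream code would shard each input sequence across the ranks of
    # sp_group before the forward pass, so every rank attends over its chunk.
    return sp_group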


@@ -245,6 +245,8 @@ class AxolotlInputConfig(
    val_set_size: float | None = Field(default=0.0)
    sequence_parallel_size: int | None = 1
    special_tokens: SpecialTokensConfig | None = None
    tokens: list[str] | None = None
    added_tokens_overrides: dict[int, str] | None = None
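As a rough illustration of how the fields shown in this hunk behave, here is a trimmed-down Pydantic stand-in. It is not the real AxolotlInputConfig (which declares many more fields, including the SpecialTokensConfig type omitted here), and the token values are made up.

from pydantic import BaseModel, Field

class _InputConfigSubset(BaseModel):
    # Hypothetical subset mirroring the fields in the hunk above.
    val_set_size: float | None = Field(default=0.0)
    sequence_parallel_size: int | None = 1
    tokens: list[str] | None = None
    added_tokens_overrides: dict[int, str] | None = None

cfg = _InputConfigSubset(
    tokens=["<dummy_token>"],  # made-up extra vocabulary entry
    added_tokens_overrides={128002: "<dummy_override>"},  # made-up id -> token mapping
)
print(cfg.model_dump())  # defaults fill in val_set_size and sequence_parallel_size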