Compare commits: feat/wizar...seq-parall (2 commits)

Commits: ee489d16bf, d88e071120
@@ -59,6 +59,7 @@ from axolotl.core.training_args import (
     AxolotlTrainingArguments,
 )
 from axolotl.integrations.base import PluginManager
+from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
@@ -746,6 +747,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = 64
+            if self.cfg.sp_ulysses_degree:
+                data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
+                    USPRingAttnType.ZIGZAG,
+                    sp_ulysses_degree=self.cfg.sp_ulysses_degree
+                )

         if self.cfg.reward_model:
             data_collator_kwargs["max_length"] = self.cfg.sequence_len
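Note on what the collator receives here: `get_extract_fn` (defined in the new module below) returns a `functools.partial` over one of yunchang's `EXTRACT_FUNC_DICT` entries with `rank`, `world_size`, `rd` (ring degree), and `ud` (Ulysses degree) pre-bound. A minimal sketch of the resulting behavior, using a stand-in for the extract function — the stand-in's contiguous slicing is an illustrative assumption, since yunchang's zigzag variant additionally interleaves chunks to balance causal attention work:

    from functools import partial

    import torch

    def fake_extract(value, rank, world_size, rd, ud, dim=1):
        # stand-in, not yunchang's implementation: keep only this rank's
        # contiguous slice of the sequence dimension
        return value.chunk(world_size, dim=dim)[rank]

    extract_fn = partial(fake_extract, rank=0, world_size=8, rd=4, ud=2)
    input_ids = torch.arange(32).reshape(1, 32)  # (batch, seq)
    print(extract_fn(input_ids).shape)           # torch.Size([1, 4])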
@@ -206,6 +206,16 @@ class AxolotlTrainingMixins:
         },
     )
+    sp_ulysses_degree: Optional[int] = field(
+        default=None,
+        metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
+    )
+
+    sp_ring_degree: Optional[int] = field(
+        default=None,
+        metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
+    )
+


 @dataclass
 class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
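The two degrees are expected to factor the GPU count: everywhere else in this diff the ring degree is derived as `world_size // sp_ulysses_degree` (see `get_extract_fn` and `set_usp_parallel_group` below), and none of the hunks shown actually read `sp_ring_degree` yet. A quick sanity check of the arithmetic:

    world_size = 8                  # total ranks
    sp_ulysses_degree = 2           # attention heads split across 2 ranks
    sp_ring_degree = world_size // sp_ulysses_degree  # 4-rank ring
    assert sp_ulysses_degree * sp_ring_degree == world_size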
src/axolotl/monkeypatch/attention/sequence_parallel/__init__.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+from enum import Enum
+from functools import partial
+
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
+
+from axolotl.utils.distributed import get_world_size, get_rank
+
+
+class USPRingAttnType(Enum):
+    BASIC = "basic"
+    ZIGZAG = "zigzag"
+    STRIPE = "stripe"
+
+def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
+    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
+
+    fa_forward = build_usp_fa_forward(ring_impl_type)
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
+
+def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
+    fn = EXTRACT_FUNC_DICT["basic"]
+    if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
+        fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
+
+    # map bad key upstream
+    elif ring_impl_type == USPRingAttnType.STRIPE:
+        fn = EXTRACT_FUNC_DICT["strip"]
+
+    world_size = get_world_size()
+    rd = world_size // sp_ulysses_degree
+
+    return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
+
+def set_usp_parallel_group(sp_ulysses_degree):
+    """
+    setup distributed parallel group for USP attention
+    make sure this gets called before building any USP attention modules
+    :param sp_ulysses_degree:
+    :return:
+    """
+    world_size = get_world_size()
+    rank = get_rank()
+    sp_ring_degree = world_size // sp_ulysses_degree
+    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
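Putting the new helpers together: the docstring above warns that the process groups must exist before any USP attention module is built, and `build_usp_fa_forward` (next file) constructs a `LongContextAttention` module, so the setup call has to come first. A sketch of the apparent intended call order, assuming `torch.distributed` is already initialized:

    from axolotl.monkeypatch.attention.sequence_parallel import (
        USPRingAttnType,
        apply_usp_attn_patch,
        get_extract_fn,
        set_usp_parallel_group,
    )

    sp_ulysses_degree = 2

    # 1. create the Ulysses/ring process groups first
    set_usp_parallel_group(sp_ulysses_degree)

    # 2. swap transformers' flash_attention_2 kernel for the USP forward
    apply_usp_attn_patch(USPRingAttnType.ZIGZAG)

    # 3. bind the per-rank shard extractor for the data collator
    extract_fn = get_extract_fn(
        USPRingAttnType.ZIGZAG, sp_ulysses_degree=sp_ulysses_degree
    )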
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from enum import Enum
+from typing import Optional, Tuple, Callable
+
+import torch
+from yunchang import LongContextAttention
+
+from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
+
+
+def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
+    usp_attn = LongContextAttention(ring_impl_type.value)
+
+    def flash_attention_forward(
+        module: torch.nn.Module,  # pylint: disable=unused-argument
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],  # pylint: disable=unused-argument
+        dropout: float = 0.0,
+        scaling: Optional[float] = None,
+        sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
+        softcap: Optional[float] = None,
+        **kwargs,  # pylint: disable=unused-argument
+    ) -> Tuple[torch.Tensor, None]:
+        attn_output = usp_attn(
+            query,
+            key,
+            value,
+            dropout_p=dropout,
+            softmax_scale=scaling,
+            causal=True,
+            softcap=softcap,
+        )
+        return attn_output, None
+
+    return flash_attention_forward
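Since `apply_usp_attn_patch` overwrites the "flash_attention_2" entry in transformers' `ALL_ATTENTION_FUNCTIONS` registry, any model whose config requests that implementation should route attention through this closure without further changes, on transformers versions that dispatch through the registry. A hedged sketch; the checkpoint name is a placeholder:

    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.attention.sequence_parallel import (
        USPRingAttnType,
        apply_usp_attn_patch,
        set_usp_parallel_group,
    )

    set_usp_parallel_group(sp_ulysses_degree=2)
    apply_usp_attn_patch(USPRingAttnType.ZIGZAG)

    # attention calls now hit flash_attention_forward from usp.py
    model = AutoModelForCausalLM.from_pretrained(
        "org/model",  # placeholder checkpoint
        attn_implementation="flash_attention_2",
    )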
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
 """

 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, Callable

 import numpy as np
 from transformers import PreTrainedTokenizerBase
@@ -53,6 +53,7 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
+    sp_extract_fn: Optional[Callable] = None

    def __call__(self, features, return_tensors=None):
        labels = None
@@ -121,6 +122,10 @@ class DataCollatorForSeq2Seq:

         return features

+    def seq_parallel_split(self, features):
+        if self.sp_extract_fn:
+            pass
+        return features

 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
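`seq_parallel_split` is still a stub in these commits: the `if self.sp_extract_fn:` branch just executes `pass`. One plausible completion — purely an assumption about where the PR is headed, not part of the diff — would run each padded tensor through the bound extract function so every rank keeps only its sequence shard:

    # hypothetical body for the stub above, not what the PR ships
    def seq_parallel_split(self, features):
        if self.sp_extract_fn:
            for key in ("input_ids", "labels", "position_ids", "attention_mask"):
                if key in features:
                    features[key] = self.sp_extract_fn(features[key])
        return features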
@@ -832,6 +832,8 @@ class AxolotlInputConfig(

     eager_attention: Optional[bool] = None

+    sp_ulysses_degree: Optional[int] = None
+
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None
@@ -86,6 +86,12 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))


+def get_rank():
+    if not is_distributed():
+        return 0
+    return dist.get_rank()
+
+
 @contextmanager
 def zero_only():
     """
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
-    elif cfg.sample_packing:
+    elif cfg.sample_packing or cfg.sp_ulysses_degree:
         drop_long_kwargs = {}
         if filter_map_kwargs:
             drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"