Compare commits
3 Commits
flx_attn_s
...
seq-parall
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee489d16bf | ||
|
|
d88e071120 | ||
|
|
a4170030ab |
4
.github/workflows/multi-gpu-e2e.yml
vendored
4
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -4,6 +4,10 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/*.py'
|
- 'tests/e2e/multigpu/*.py'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- 'setup.py'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- '.github/workflows/multi-gpu-e2e.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|||||||
@@ -37,15 +37,11 @@ temp_dir = tempfile.mkdtemp()
|
|||||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||||
f.write(dockerfile_contents)
|
f.write(dockerfile_contents)
|
||||||
|
|
||||||
cicd_image = (
|
cicd_image = Image.from_dockerfile(
|
||||||
Image.from_dockerfile(
|
pathlib.Path(temp_dir) / "Dockerfile",
|
||||||
pathlib.Path(temp_dir) / "Dockerfile",
|
force_build=True,
|
||||||
force_build=True,
|
gpu="A10G",
|
||||||
gpu="A10G",
|
).env(df_args)
|
||||||
)
|
|
||||||
.env(df_args)
|
|
||||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
|
||||||
)
|
|
||||||
|
|
||||||
app = App("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|||||||
@@ -123,8 +123,6 @@ class ModalCloud(Cloud):
|
|||||||
if env := self.get_env():
|
if env := self.get_env():
|
||||||
image = image.env(env)
|
image = image.env(env)
|
||||||
|
|
||||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def get_secrets(self):
|
def get_secrets(self):
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ from axolotl.core.training_args import (
|
|||||||
AxolotlTrainingArguments,
|
AxolotlTrainingArguments,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -746,6 +747,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
if self.cfg.sp_ulysses_degree:
|
||||||
|
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
|
||||||
|
USPRingAttnType.ZIGZAG,
|
||||||
|
sp_ulysses_degree=self.cfg.sp_ulysses_degree
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
@@ -831,9 +837,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.flex_attention is True:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
@@ -206,6 +206,16 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sp_ulysses_degree: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Ulysses parallelism for hybrid sequence parallel long context attn"},
|
||||||
|
)
|
||||||
|
|
||||||
|
sp_ring_degree: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Ring attention parallelism for sequence parallel long context attn"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import get_world_size, get_rank
|
||||||
|
|
||||||
|
|
||||||
|
class USPRingAttnType(Enum):
|
||||||
|
BASIC = "basic"
|
||||||
|
ZIGZAG = "zigzag"
|
||||||
|
STRIPE = "stripe"
|
||||||
|
|
||||||
|
def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
|
||||||
|
from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward
|
||||||
|
|
||||||
|
fa_forward = build_usp_fa_forward(ring_impl_type)
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward
|
||||||
|
|
||||||
|
def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
|
||||||
|
fn = EXTRACT_FUNC_DICT["basic"]
|
||||||
|
if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
|
||||||
|
fn = EXTRACT_FUNC_DICT[ring_impl_type.value]
|
||||||
|
|
||||||
|
# map bad key upstream
|
||||||
|
elif ring_impl_type == USPRingAttnType.STRIPE:
|
||||||
|
fn = EXTRACT_FUNC_DICT["strip"]
|
||||||
|
|
||||||
|
world_size = get_world_size()
|
||||||
|
rd = world_size // sp_ulysses_degree
|
||||||
|
|
||||||
|
return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)
|
||||||
|
|
||||||
|
def set_usp_parallel_group(sp_ulysses_degree):
|
||||||
|
"""
|
||||||
|
setup distributed parallel group for USP attention
|
||||||
|
make sure this gets called before building any USP attention modules
|
||||||
|
:param sp_ulysses_degree:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
rank = get_rank()
|
||||||
|
sp_ring_degree = world_size // sp_ulysses_degree
|
||||||
|
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
|
||||||
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
36
src/axolotl/monkeypatch/attention/sequence_parallel/usp.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Tuple, Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from yunchang import LongContextAttention
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType
|
||||||
|
|
||||||
|
|
||||||
|
def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
|
||||||
|
usp_attn = LongContextAttention(ring_impl_type.value)
|
||||||
|
|
||||||
|
def flash_attention_forward(
|
||||||
|
module: torch.nn.Module, # pylint: disable=unused-argument
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor], # pylint: disable=unused-argument
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scaling: Optional[float] = None,
|
||||||
|
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
||||||
|
softcap: Optional[float] = None,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
) -> Tuple[torch.Tensor, None]:
|
||||||
|
attn_output = usp_attn(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
dropout_p=dropout,
|
||||||
|
softmax_scale=scaling,
|
||||||
|
causal=True,
|
||||||
|
softcap=softcap,
|
||||||
|
)
|
||||||
|
return attn_output, None
|
||||||
|
|
||||||
|
return flash_attention_forward
|
||||||
@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
def get_packed_mask_from_pos_ids(position_ids):
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for i, row in enumerate(position_ids):
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
|
|
||||||
for i, seq_len in enumerate(seq_lengths):
|
|
||||||
start_id = start_indices[i]
|
|
||||||
doc_mask[start_id : start_id + seq_len] = (
|
|
||||||
(i+1) * doc_mask[start_id : start_id + seq_len]
|
|
||||||
)
|
|
||||||
if padding_length:
|
|
||||||
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
|
||||||
|
|
||||||
results.append(doc_mask)
|
|
||||||
|
|
||||||
return torch.stack(results)
|
|
||||||
|
|
||||||
|
|
||||||
def get_seqlens_from_pos_ids(position_ids):
|
|
||||||
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
max_seq_len = position_ids.shape[1]
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
totalseqlens = []
|
|
||||||
|
|
||||||
for row in position_ids:
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
if padding_length:
|
|
||||||
seq_lengths = torch.cat(
|
|
||||||
[
|
|
||||||
seq_lengths,
|
|
||||||
torch.tensor(
|
|
||||||
[len(row) - torch.sum(seq_lengths)],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(seq_lengths)
|
|
||||||
totalseqlens.append(len(adjusted_row))
|
|
||||||
|
|
||||||
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
|
|||||||
when they attend to each other within that sequence.
|
when they attend to each other within that sequence.
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
if len(mask.size()) == 4:
|
|
||||||
return mask
|
|
||||||
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DataCollator for axolotl to pad labels and position_ids for packed sequences
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union, Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
@@ -53,6 +53,7 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
sp_extract_fn: Optional[Callable] = None
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -121,6 +122,10 @@ class DataCollatorForSeq2Seq:
|
|||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def seq_parallel_split(self, features):
|
||||||
|
if self.sp_extract_fn:
|
||||||
|
pass
|
||||||
|
return features
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
|||||||
@@ -823,7 +823,6 @@ class AxolotlInputConfig(
|
|||||||
xformers_attention: Optional[bool] = None
|
xformers_attention: Optional[bool] = None
|
||||||
sdp_attention: Optional[bool] = None
|
sdp_attention: Optional[bool] = None
|
||||||
s2_attention: Optional[bool] = None
|
s2_attention: Optional[bool] = None
|
||||||
flex_attention: Optional[bool] = None
|
|
||||||
flash_attention: Optional[bool] = None
|
flash_attention: Optional[bool] = None
|
||||||
flash_attn_cross_entropy: Optional[bool] = None
|
flash_attn_cross_entropy: Optional[bool] = None
|
||||||
flash_attn_rms_norm: Optional[bool] = None
|
flash_attn_rms_norm: Optional[bool] = None
|
||||||
@@ -833,6 +832,8 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: Optional[bool] = None
|
eager_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
sp_ulysses_degree: Optional[int] = None
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: Optional[bool] = None
|
unsloth_cross_entropy_loss: Optional[bool] = None
|
||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
@@ -1790,26 +1791,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_flex_torch_version(cls, data):
|
|
||||||
if (data.get("flex_attention") is not None) and (
|
|
||||||
data.get("flex_attention") is True
|
|
||||||
):
|
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
|
||||||
torch_version = env_capabilities.get("torch_version")
|
|
||||||
|
|
||||||
if torch_version is None:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
|
||||||
|
|
||||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
|
||||||
raise ValueError(
|
|
||||||
"Flex attention is not supported on torch version < 2.5.1"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_auto(cls, data):
|
def check_torch_compile_auto(cls, data):
|
||||||
|
|||||||
@@ -86,6 +86,12 @@ def get_world_size():
|
|||||||
return int(os.getenv("WORLD_SIZE", "1"))
|
return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_distributed():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_only():
|
def zero_only():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if "auto_map" in self.model_config:
|
if "auto_map" in self.model_config:
|
||||||
@@ -707,13 +707,7 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
|
if self.cfg.flash_attention:
|
||||||
if self.cfg.flex_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flex_attention"
|
|
||||||
)
|
|
||||||
elif self.cfg.flash_attention:
|
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
@@ -1119,7 +1113,7 @@ class ModelLoader:
|
|||||||
should_convert = (
|
should_convert = (
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
|
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing or cfg.sp_ulysses_degree:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
|
|||||||
Reference in New Issue
Block a user