support for true batches with multipack (#1230)
* support for true batches with multipack * patch the map dataset fetcher to handle batches with packed indexes * patch 4d mask creation for sdp attention * better handling for BetterTransformer * patch general case for 4d mask * setup forward patch. WIP * fix patch file * support for multipack w/o flash attention for llama * cleanup * add warning about bf16 vs fp16 for multipack with sdpa * bugfixes * add 4d multipack tests, refactor patches * update tests and add warnings * fix e2e file check * skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -98,6 +98,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
@@ -229,11 +233,19 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -243,11 +255,19 @@ class AxolotlTrainer(Trainer):
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -860,6 +880,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.flash_attention is not True
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
@@ -964,6 +987,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
|
||||
0
src/axolotl/monkeypatch/data/__init__.py
Normal file
0
src/axolotl/monkeypatch/data/__init__.py
Normal file
46
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
Normal file
46
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import torch
|
||||
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
|
||||
from torch.utils.data._utils.worker import _worker_loop
|
||||
|
||||
|
||||
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
def fetch(self, possibly_batched_index):
|
||||
if isinstance(possibly_batched_index[0], list):
|
||||
data = [None for i in possibly_batched_index]
|
||||
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
|
||||
if self.auto_collation:
|
||||
if (
|
||||
hasattr(self.dataset, "__getitems__")
|
||||
and self.dataset.__getitems__
|
||||
):
|
||||
data[i] = self.dataset.__getitems__(possibly_batched_index_)
|
||||
else:
|
||||
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
|
||||
else:
|
||||
data[i] = self.dataset[possibly_batched_index_]
|
||||
else:
|
||||
if self.auto_collation:
|
||||
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
||||
data = self.dataset.__getitems__(possibly_batched_index)
|
||||
else:
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
else:
|
||||
data = self.dataset[possibly_batched_index]
|
||||
return self.collate_fn(data)
|
||||
|
||||
|
||||
def patch_fetchers():
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
|
||||
|
||||
def patched_worker_loop(*args, **kwargs):
|
||||
patch_fetchers()
|
||||
return _worker_loop(*args, **kwargs)
|
||||
|
||||
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
patch_fetchers()
|
||||
@@ -1,142 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# sdp-attn start
|
||||
#
|
||||
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
#
|
||||
# sdp-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
@@ -5,38 +5,11 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
|
||||
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
import transformers.modeling_attn_mask_utils
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -1,8 +1,15 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -89,7 +96,6 @@ def get_cu_seqlens(attn_mask):
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
@@ -135,7 +141,18 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
# Find the maximum value across all tensors
|
||||
max_value = max(t.max() for t in results)
|
||||
|
||||
# Find the length of the longest tensor
|
||||
max_length = max(t.size(0) for t in results)
|
||||
|
||||
# Pad each tensor to the same length and collect them in a list
|
||||
padded_results = [
|
||||
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
|
||||
]
|
||||
|
||||
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
@@ -149,3 +166,62 @@ def set_module_name(model, name, value):
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -24,6 +23,11 @@ from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
@@ -124,7 +128,7 @@ def train(
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
@@ -149,7 +153,10 @@ def train(
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
@@ -195,7 +202,7 @@ def train(
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
@@ -132,24 +132,26 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -159,24 +161,26 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features)
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -202,6 +202,20 @@ def validate_config(cfg):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
@@ -350,17 +364,24 @@ def validate_config(cfg):
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
|
||||
@@ -834,7 +834,7 @@ def encode_packed_pretraining(
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
batch_size=1,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_size * max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
@@ -842,15 +842,16 @@ def encode_packed_pretraining(
|
||||
|
||||
chunked_data = defaultdict(list)
|
||||
|
||||
for data in sampler:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
for batch in sampler:
|
||||
for data in batch:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
|
||||
return chunked_data
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from transformers import ( # noqa: F401
|
||||
@@ -324,13 +323,13 @@ def load_model(
|
||||
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
|
||||
hijack_llama_sdp_attention,
|
||||
elif cfg.sample_packing:
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
)
|
||||
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
elif cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
@@ -506,6 +505,12 @@ def load_model(
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif cfg.sdp_attention:
|
||||
model_kwargs["attn_implementation"] = "sdpa"
|
||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||
elif cfg.eager_attention:
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
try:
|
||||
if (
|
||||
@@ -749,6 +754,8 @@ def load_model(
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
if cfg.adapter is not None:
|
||||
|
||||
@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = None
|
||||
self.batch_size = batch_size
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
@@ -147,7 +147,13 @@ class MultipackBatchSampler(BatchSampler):
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
batches = [
|
||||
[
|
||||
[indices[b_idx] for b_idx in batch]
|
||||
for batch in batches[i : i + self.batch_size]
|
||||
]
|
||||
for i in range(0, len(batches), self.batch_size)
|
||||
]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
@@ -189,7 +195,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.batch_max_len
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
)
|
||||
- 1
|
||||
),
|
||||
|
||||
@@ -237,11 +237,17 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention:
|
||||
batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=cfg.micro_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
@@ -249,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
data_loader_len = len(data_loader) // batch_size
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
|
||||
Reference in New Issue
Block a user