diff --git a/README.md b/README.md
index 80c0fedff..a7be20dcd 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,9 @@ Features:
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
+- Advanced Topics
+ - [Multipack](./docs/multipack.md)
+ - [RLHF & DPO](./docs/rlhf.md)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
diff --git a/docs/images/4d-mask.png b/docs/images/4d-mask.png
new file mode 100644
index 000000000..181e693cf
Binary files /dev/null and b/docs/images/4d-mask.png differ
diff --git a/docs/multipack.md b/docs/multipack.md
index 2a55148b2..bee13b62c 100644
--- a/docs/multipack.md
+++ b/docs/multipack.md
@@ -1,4 +1,11 @@
-# Multipack
+# Multipack (Sample Packing)
+
+## Visualization of Multipack with Flash Attention
+
+Because Flash Attention simply drops the attention mask, we do not need to
+construct a 4d attention mask. We only need to concatenate the sequences into
+a single batch and let flash attention know where each new sequence begins.
+
4k context, bsz =4,
each character represents 256 tokens
@@ -49,3 +56,18 @@ w packing ( note it's the same effective number of tokens per step, but a true b
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
+
+cu_seqlens:
+[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
+
+
+## Multipack without Flash Attention
+
+Multipack can still be achieved without Flash attention, but with lower packing
+efficiency as we are not able to join multiple batches into a single batch due to
+context length limits without flash attention. We can use either Pytorch's Scaled
+Dot Product Attention implementation or native Pytorch attention implementation
+along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
+to pack sequences together and avoid cross attention.
+
+
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index c8aea4a71..636a23ba5 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -6,6 +6,7 @@ import logging
from dataclasses import dataclass, field
from typing import Optional
+import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 96054dc50..5e52dafab 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -98,6 +98,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
+ multipack_real_batches: bool = field(
+ default=False,
+ metadata={"help": "Use real batches for efficient training."},
+ )
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
@@ -229,11 +233,19 @@ class AxolotlTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
+ if self.args.multipack_real_batches:
+ batch_size = self.args.per_device_train_batch_size
+ batch_max_len = self.args.max_seq_length
+ else:
+ batch_size = 1
+ batch_max_len = (
+ self.args.per_device_train_batch_size * self.args.max_seq_length
+ )
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
- self.args.train_batch_size,
+ batch_size=batch_size,
drop_last=True,
- batch_max_len=self._train_batch_size * self.args.max_seq_length,
+ batch_max_len=batch_max_len,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
@@ -243,11 +255,19 @@ class AxolotlTrainer(Trainer):
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
+ if self.args.multipack_real_batches:
+ batch_size = self.args.per_device_eval_batch_size
+ batch_max_len = self.args.max_seq_length
+ else:
+ batch_size = 1
+ batch_max_len = (
+ self.args.per_device_eval_batch_size * self.args.max_seq_length
+ )
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
- self.args.per_device_eval_batch_size,
+ batch_size=batch_size,
drop_last=True,
- batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
+ batch_max_len=batch_max_len,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
@@ -860,6 +880,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
+ training_arguments_kwargs["multipack_real_batches"] = (
+ self.cfg.flash_attention is not True
+ )
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
@@ -964,6 +987,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if use_batch_sampler_collator:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
collator = V2BatchSamplerDataCollatorForSeq2Seq
+ elif (
+ self.cfg.model_config_type in ["llama"]
+ and self.cfg.flash_attention is not True
+ ):
+ collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
diff --git a/src/axolotl/monkeypatch/data/__init__.py b/src/axolotl/monkeypatch/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
new file mode 100644
index 000000000..2e9364e3a
--- /dev/null
+++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
@@ -0,0 +1,46 @@
+"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
+# pylint: disable=protected-access
+
+import torch
+from torch.utils.data._utils.fetch import _BaseDatasetFetcher
+from torch.utils.data._utils.worker import _worker_loop
+
+
+class _MapDatasetFetcher(_BaseDatasetFetcher):
+ def fetch(self, possibly_batched_index):
+ if isinstance(possibly_batched_index[0], list):
+ data = [None for i in possibly_batched_index]
+ for i, possibly_batched_index_ in enumerate(possibly_batched_index):
+ if self.auto_collation:
+ if (
+ hasattr(self.dataset, "__getitems__")
+ and self.dataset.__getitems__
+ ):
+ data[i] = self.dataset.__getitems__(possibly_batched_index_)
+ else:
+ data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
+ else:
+ data[i] = self.dataset[possibly_batched_index_]
+ else:
+ if self.auto_collation:
+ if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
+ data = self.dataset.__getitems__(possibly_batched_index)
+ else:
+ data = [self.dataset[idx] for idx in possibly_batched_index]
+ else:
+ data = self.dataset[possibly_batched_index]
+ return self.collate_fn(data)
+
+
+def patch_fetchers():
+ torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
+ torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
+
+
+def patched_worker_loop(*args, **kwargs):
+ patch_fetchers()
+ return _worker_loop(*args, **kwargs)
+
+
+torch.utils.data._utils.worker._worker_loop = patched_worker_loop
+patch_fetchers()
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py b/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
deleted file mode 100644
index cfed8cb17..000000000
--- a/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
+++ /dev/null
@@ -1,142 +0,0 @@
-"""
-Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
-"""
-
-import warnings
-from typing import Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-import transformers.models.llama.modeling_llama
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
-
-
-def hijack_llama_sdp_attention():
- transformers.models.llama.modeling_llama.LlamaAttention.forward = (
- sdp_attention_forward
- )
-
-
-def sdp_attention_forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
- **kwargs, # pylint: disable=unused-argument
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # pylint: disable=duplicate-code
- bsz, q_len, _ = hidden_states.size()
-
- if not hasattr(self, "pretraining_tp"):
- self.pretraining_tp = 1
-
- if self.pretraining_tp > 1:
- key_value_slicing = (
- self.num_key_value_heads * self.head_dim
- ) // self.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
- )
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ).transpose(1, 2)
- key_states = key_states.view(
- bsz, q_len, self.num_key_value_heads, self.head_dim
- ).transpose(1, 2)
- value_states = value_states.view(
- bsz, q_len, self.num_key_value_heads, self.head_dim
- ).transpose(1, 2)
- # [bsz, q_len, nh, hd]
- # [bsz, nh, q_len, hd]
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids
- )
- # [bsz, nh, t, hd]
-
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states) if use_cache else None
-
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
- if output_attentions:
- warnings.warn(
- "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
- )
-
- #
- # sdp-attn start
- #
-
- with torch.backends.cuda.sdp_kernel():
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- is_causal=False,
- )
-
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
- #
- # sdp-attn end
- #
-
- if self.pretraining_tp > 1:
- attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(
- self.hidden_size // self.pretraining_tp, dim=1
- )
- attn_output = sum(
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.pretraining_tp)
- )
- else:
- attn_output = self.o_proj(attn_output)
-
- return attn_output, None, past_key_value
diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py
index d69433baa..5738bb543 100644
--- a/src/axolotl/monkeypatch/llama_expand_mask.py
+++ b/src/axolotl/monkeypatch/llama_expand_mask.py
@@ -5,38 +5,11 @@ from typing import Optional
import torch
+from axolotl.monkeypatch.utils import mask_2d_to_4d
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- This expansion handles packed sequences so that sequences share the same attention mask integer value
- when they attend to each other within that sequence.
- This expansion transforms the mask to lower triangular form to prevent future peeking.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- mask = mask.unsqueeze(1).unsqueeze(2)
- mask = mask.expand(bsz, 1, tgt_len, src_len)
-
- # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
- binary_mask = torch.where(
- mask != 0,
- torch.tensor(1).to(dtype),
- torch.tensor(0).to(dtype),
- )
-
- # Create a block-diagonal mask.
- # we multiply by the binary mask so that 0's in the original mask are correctly excluded
- zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
-
- # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
- lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
- mask.device
- )
-
- # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
- masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+ masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
new file mode 100644
index 000000000..540c5577a
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -0,0 +1,26 @@
+"""
+Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
+"""
+
+from axolotl.monkeypatch.utils import (
+ patched_prepare_4d_causal_attention_mask,
+ patched_prepare_4d_causal_attention_mask_for_sdpa,
+)
+
+
+def hijack_llama_prepare_4d_mask():
+ import transformers.modeling_attn_mask_utils
+ import transformers.models.llama.modeling_llama
+
+ transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
+ patched_prepare_4d_causal_attention_mask_for_sdpa
+ )
+ transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
+ patched_prepare_4d_causal_attention_mask_for_sdpa
+ )
+ transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
+ patched_prepare_4d_causal_attention_mask
+ )
+ transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
+ patched_prepare_4d_causal_attention_mask
+ )
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 193c9d0ec..63141635a 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -1,8 +1,15 @@
"""
Shared utils for the monkeypatches
"""
+from typing import Optional
+
import torch
import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -89,7 +96,6 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
-@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -135,7 +141,18 @@ def get_cu_seqlens_from_pos_ids(position_ids):
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
- return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+ # Find the maximum value across all tensors
+ max_value = max(t.max() for t in results)
+
+ # Find the length of the longest tensor
+ max_length = max(t.size(0) for t in results)
+
+ # Pad each tensor to the same length and collect them in a list
+ padded_results = [
+ F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
+ ]
+
+ return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
@@ -149,3 +166,62 @@ def set_module_name(model, name, value):
child_name = name
setattr(parent, child_name, value)
+
+
+def mask_2d_to_4d(
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ This expansion handles packed sequences so that sequences share the same attention mask integer value
+ when they attend to each other within that sequence.
+ This expansion transforms the mask to lower triangular form to prevent future peeking.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+ # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+ binary_mask = torch.where(
+ mask != 0,
+ torch.tensor(1).to(dtype),
+ torch.tensor(0).to(dtype),
+ )
+
+ # Create a block-diagonal mask.
+ # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+ zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+ # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
+ lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+ mask.device
+ )
+
+ # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+ masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+
+ return masked_zero_one_mask
+
+
+def patched_prepare_4d_causal_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ *args,
+):
+ dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+ return _prepare_4d_causal_attention_mask(
+ mask_2d_to_4d(attention_mask, dtype=dtype),
+ *args,
+ )
+
+
+def patched_prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask: Optional[torch.Tensor],
+ *args,
+):
+ dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
+ return _prepare_4d_causal_attention_mask_for_sdpa(
+ mask_2d_to_4d(attention_mask, dtype=dtype),
+ *args,
+ )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index d0f58bca9..014c281ec 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -11,7 +11,6 @@ import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
-from optimum.bettertransformer import BetterTransformer
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -24,6 +23,11 @@ from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
+try:
+ from optimum.bettertransformer import BetterTransformer
+except ImportError:
+ BetterTransformer = None
+
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -124,7 +128,7 @@ def train(
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
- if cfg.flash_optimum:
+ if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
@@ -149,7 +153,10 @@ def train(
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
- enable_flash=True, enable_math=True, enable_mem_efficient=True
+ # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
+ enable_flash=True,
+ enable_math=True,
+ enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
@@ -195,7 +202,7 @@ def train(
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
- if cfg.flash_optimum:
+ if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index da1ee392f..8512b9408 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -132,24 +132,26 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
- chunked_data = {}
- for feature in features[0].keys():
- if feature == "length":
- continue
- if feature == "attention_mask":
- arrays = [
- (1) * np.array(item[feature])
- for item in features
- if feature in item
- ]
- chunked_data[feature] = np.concatenate(arrays)
- else:
- arrays = [
- np.array(item[feature]) for item in features if feature in item
- ]
- chunked_data[feature] = np.concatenate(arrays)
- features = [chunked_data]
- return super().__call__(features, return_tensors=return_tensors)
+ if not isinstance(features[0], list):
+ features = [features]
+ out_features = [{} for _ in features]
+ for i, features_ in enumerate(features):
+ for feature in features_[0].keys():
+ if feature == "length":
+ continue
+ if feature == "attention_mask":
+ arrays = [
+ (1) * np.array(item[feature])
+ for i, item in enumerate(features_)
+ if feature in item
+ ]
+ out_features[i][feature] = np.concatenate(arrays)
+ else:
+ arrays = [
+ np.array(item[feature]) for item in features_ if feature in item
+ ]
+ out_features[i][feature] = np.concatenate(arrays)
+ return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
@@ -159,24 +161,26 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
- chunked_data = {}
- for feature in features[0].keys():
- if feature == "length":
- continue
- if feature == "attention_mask":
- arrays = [
- (i + 1) * np.array(item[feature])
- for i, item in enumerate(features)
- if feature in item
- ]
- chunked_data[feature] = np.concatenate(arrays)
- else:
- arrays = [
- np.array(item[feature]) for item in features if feature in item
- ]
- chunked_data[feature] = np.concatenate(arrays)
- features = [chunked_data]
- return super().__call__(features, return_tensors=return_tensors)
+ if not isinstance(features[0], list):
+ features = [features]
+ out_features = [{} for _ in features]
+ for i, features_ in enumerate(features):
+ for feature in features_[0].keys():
+ if feature == "length":
+ continue
+ if feature == "attention_mask":
+ arrays = [
+ (i + 1) * np.array(item[feature])
+ for i, item in enumerate(features_)
+ if feature in item
+ ]
+ out_features[i][feature] = np.concatenate(arrays)
+ else:
+ arrays = [
+ np.array(item[feature]) for item in features_ if feature in item
+ ]
+ out_features[i][feature] = np.concatenate(arrays)
+ return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 3ea48f6bb..75b4e5220 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -202,6 +202,20 @@ def validate_config(cfg):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
+ if (
+ # pylint: disable=too-many-boolean-expressions
+ not (cfg.bf16 or cfg.bfloat16)
+ and (cfg.fp16 or cfg.float16)
+ and not cfg.adapter
+ and not cfg.flash_attention
+ and cfg.sample_packing
+ ):
+ LOG.warning(
+ "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
+ )
+ # ValueError: Attempting to unscale FP16 gradients.
+ # OR
+ # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
if cfg.max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
@@ -350,17 +364,24 @@ def validate_config(cfg):
+ "point to its path, and remove model_revision from the config."
)
- if cfg.sample_packing and cfg.sdp_attention:
- # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
- raise ValueError(
- "sample_packing not compatible with sdp_attention. Use flash_attention"
- )
+ # if cfg.sample_packing and cfg.sdp_attention:
+ # # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
+ # raise ValueError(
+ # "sample_packing not compatible with sdp_attention. Use flash_attention"
+ # )
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
+ if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
+ # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
+ LOG.warning(
+ "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
+ "This may work on H100s."
+ )
+
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 39766868e..5e6ceb6cb 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -834,7 +834,7 @@ def encode_packed_pretraining(
sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
- batch_size=batch_size,
+ batch_size=1,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=get_dataset_lengths(train_dataset),
@@ -842,15 +842,16 @@ def encode_packed_pretraining(
chunked_data = defaultdict(list)
- for data in sampler:
- features = train_dataset[data]
- features["labels"] = features["input_ids"].copy()
- collated_features = collate_fn(features)
+ for batch in sampler:
+ for data in batch:
+ features = train_dataset[data]
+ features["labels"] = features["input_ids"].copy()
+ collated_features = collate_fn(features)
- for feature in features.keys():
- if feature == "length":
- continue
- chunked_data[feature].append(collated_features[feature].squeeze(0))
+ for feature in features.keys():
+ if feature == "length":
+ continue
+ chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index ad3ac7bc6..224e3a258 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -8,7 +8,6 @@ import addict
import bitsandbytes as bnb
import torch
import transformers
-from optimum.bettertransformer import BetterTransformer
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
@@ -324,13 +323,13 @@ def load_model(
LOG.info("patching with xformers attention")
hijack_llama_attention()
- elif cfg.sdp_attention:
- from axolotl.monkeypatch.llama_attn_hijack_sdp import (
- hijack_llama_sdp_attention,
+ elif cfg.sample_packing:
+ from axolotl.monkeypatch.llama_patch_multipack import (
+ hijack_llama_prepare_4d_mask,
)
- LOG.info("patching with sdp attention")
- hijack_llama_sdp_attention()
+ LOG.info("patching llama _prepare_4d_causal_attention_mask*")
+ hijack_llama_prepare_4d_mask()
elif cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
@@ -506,6 +505,12 @@ def load_model(
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
+ elif cfg.sdp_attention:
+ model_kwargs["attn_implementation"] = "sdpa"
+ model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
+ elif cfg.eager_attention:
+ model_kwargs["attn_implementation"] = "eager"
+ model_config._attn_implementation = "eager" # pylint: disable=protected-access
try:
if (
@@ -749,6 +754,8 @@ def load_model(
model.config.use_cache = False
if cfg.flash_optimum:
+ from optimum.bettertransformer import BetterTransformer
+
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 629a1a44c..cf47d9639 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
):
super().__init__(sampler, batch_size, drop_last)
- self.batch_size = None
+ self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
@@ -147,7 +147,13 @@ class MultipackBatchSampler(BatchSampler):
n=1,
)
- batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
+ batches = [
+ [
+ [indices[b_idx] for b_idx in batch]
+ for batch in batches[i : i + self.batch_size]
+ ]
+ for i in range(0, len(batches), self.batch_size)
+ ]
# statistics
if set_stats:
@@ -189,7 +195,7 @@ class MultipackBatchSampler(BatchSampler):
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
- // self.batch_max_len
+ // (self.batch_max_len * self.batch_size)
)
- 1
),
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 26197b991..6d4168fbd 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -237,11 +237,17 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
+ if cfg.flash_attention:
+ batch_size = 1
+ batch_max_len = cfg.micro_batch_size * cfg.sequence_len
+ else:
+ batch_size = cfg.micro_batch_size
+ batch_max_len = cfg.sequence_len
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
- batch_size=cfg.micro_batch_size,
+ batch_size=batch_size,
drop_last=True,
- batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
+ batch_max_len=batch_max_len,
lengths=get_dataset_lengths(train_dataset),
)
@@ -249,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
- data_loader_len = len(data_loader)
+ data_loader_len = len(data_loader) // batch_size
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
new file mode 100644
index 000000000..d74d09723
--- /dev/null
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -0,0 +1,114 @@
+"""
+E2E tests for multipack fft llama using 4d attention masks
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import require_torch_2_1_1, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class Test4dMultipackLlama(unittest.TestCase):
+ """
+ Test case for Llama models using 4d attention with multipack
+ """
+
+ @require_torch_2_1_1
+ @with_temp_dir
+ def test_sdp_lora_packing(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "flash_attention": False,
+ "sdp_attention": True,
+ "sample_packing": True,
+ "pad_to_sequence_len": True,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "sequence_len": 1024,
+ "val_set_size": 0.1,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ "fp16": True,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+ @with_temp_dir
+ def test_torch_lora_packing(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "JackFram/llama-68m",
+ "flash_attention": False,
+ "sdp_attention": False,
+ "sample_packing": True,
+ "pad_to_sequence_len": True,
+ "sequence_len": 1024,
+ "load_in_8bit": True,
+ "adapter": "lora",
+ "lora_r": 32,
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "lora_target_linear": True,
+ "val_set_size": 0.1,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 2,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch",
+ "lr_scheduler": "cosine",
+ "max_steps": 20,
+ "save_steps": 10,
+ "eval_steps": 10,
+ "fp16": True,
+ }
+ )
+ normalize_config(cfg)
+ cli_args = TrainerCliArgs()
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 96ff5eee8..dda08a463 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -33,6 +33,7 @@ class TestFusedLlama(unittest.TestCase):
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
+ "pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 203824fc9..837b4734f 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -4,7 +4,9 @@ helper utils for tests
import os
import shutil
import tempfile
+import unittest
from functools import wraps
+from importlib.metadata import version
from pathlib import Path
@@ -31,3 +33,15 @@ def most_recent_subdir(path):
subdir = max(subdirectories, key=os.path.getctime)
return subdir
+
+
+def require_torch_2_1_1(test_case):
+ """
+ Decorator marking a test that requires torch >= 2.1.1
+ """
+
+ def is_min_2_1_1():
+ torch_version = version("torch")
+ return torch_version >= "2.1.1"
+
+ return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py
index cce421e88..4521cd07b 100644
--- a/tests/monkeypatch/test_llama_attn_hijack_flash.py
+++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py
@@ -30,6 +30,20 @@ class TestMonkeyPatchUtils(unittest.TestCase):
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
+ def test_get_cu_seqlens_from_pos_ids_2d(self):
+ position_ids = torch.tensor(
+ [
+ [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
+ [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
+ ]
+ )
+ target_res = torch.tensor(
+ [[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
+ )
+ self.assertTrue(
+ torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
+ )
+
def test_get_max_seqlen_in_batch(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
new file mode 100644
index 000000000..50f39d60f
--- /dev/null
+++ b/tests/test_packed_batch_sampler.py
@@ -0,0 +1,99 @@
+"""Module for testing streaming dataset sequence packing"""
+import pytest
+from datasets import concatenate_datasets, load_dataset
+from torch.utils.data import DataLoader, RandomSampler
+from transformers import AutoTokenizer
+
+from axolotl.datasets import TokenizedPromptDataset
+from axolotl.prompt_strategies.completion import load
+from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+
+
+@pytest.fixture(name="tokenizer")
+def fixture_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+ tokenizer.pad_token = ""
+ return tokenizer
+
+
+@pytest.fixture(name="max_seq_length")
+def fixture_max_seq_length():
+ return 4096
+
+
+class TestBatchedSamplerPacking:
+ """
+ Test class for packing streaming dataset sequences
+ """
+
+ @pytest.mark.parametrize(
+ "batch_size, num_workers",
+ [
+ (1, 0),
+ (2, 0),
+ (1, 2),
+ (2, 2),
+ ],
+ )
+ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
+ import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
+
+ dataset = load_dataset(
+ "Trelis/tiny-shakespeare",
+ split="train",
+ )
+
+ cfg = DictDefault(
+ {
+ "train_on_inputs": True,
+ "sequence_len": max_seq_length,
+ }
+ )
+ ds_cfg = DictDefault(
+ {
+ "field": "Text",
+ }
+ )
+ completion_strategy = load(tokenizer, cfg, ds_cfg)
+ dataset_wrapper = TokenizedPromptDataset(
+ completion_strategy,
+ dataset,
+ )
+ train_dataset = concatenate_datasets([dataset_wrapper])
+ batch_sampler = MultipackBatchSampler(
+ sampler=RandomSampler(train_dataset),
+ batch_size=batch_size,
+ drop_last=True,
+ batch_max_len=max_seq_length,
+ lengths=get_dataset_lengths(train_dataset),
+ )
+
+ loader = DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
+ tokenizer=tokenizer,
+ padding=True,
+ pad_to_multiple_of=max_seq_length,
+ return_tensors="pt",
+ ),
+ num_workers=num_workers,
+ )
+ inputs = next(iter(loader))
+
+ assert inputs["input_ids"].shape == (batch_size, max_seq_length)
+ assert inputs["labels"].shape == (batch_size, max_seq_length)
+ assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
+
+ assert inputs["input_ids"].tolist()[0][0] == 2
+ assert inputs["labels"].tolist()[0][0] == -100
+ assert inputs["attention_mask"].tolist()[0][0] == 0
+ assert inputs["attention_mask"].tolist()[0][-1] > 1
+
+ if batch_size >= 2:
+ assert inputs["input_ids"].tolist()[1][0] == 2
+ assert inputs["labels"].tolist()[1][0] == -100
+ assert inputs["attention_mask"].tolist()[1][0] == 0
+ assert inputs["attention_mask"].tolist()[1][-1] > 1
diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py
index a47e3983f..be4b7075b 100644
--- a/tests/test_packed_pretraining.py
+++ b/tests/test_packed_pretraining.py
@@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se
from axolotl.utils.data import encode_packed_pretraining
-class TestPacking(unittest.TestCase):
+class TestPretrainingPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""