support for true batches with multipack (#1230)
* support for true batches with multipack * patch the map dataset fetcher to handle batches with packed indexes * patch 4d mask creation for sdp attention * better handling for BetterTransformer * patch general case for 4d mask * setup forward patch. WIP * fix patch file * support for multipack w/o flash attention for llama * cleanup * add warning about bf16 vs fp16 for multipack with sdpa * bugfixes * add 4d multipack tests, refactor patches * update tests and add warnings * fix e2e file check * skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
@@ -37,6 +37,9 @@ Features:
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Special Tokens](#special-tokens)
|
||||
- Advanced Topics
|
||||
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [Common Errors](#common-errors-)
|
||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||
- [Debugging Axolotl](#debugging-axolotl)
|
||||
|
||||
BIN
docs/images/4d-mask.png
Normal file
BIN
docs/images/4d-mask.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 239 KiB |
@@ -1,4 +1,11 @@
|
||||
# Multipack
|
||||
# Multipack (Sample Packing)
|
||||
|
||||
## Visualization of Multipack with Flash Attention
|
||||
|
||||
Because Flash Attention simply drops the attention mask, we do not need to
|
||||
construct a 4d attention mask. We only need to concatenate the sequences into
|
||||
a single batch and let flash attention know where each new sequence begins.
|
||||
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
@@ -49,3 +56,18 @@ w packing ( note it's the same effective number of tokens per step, but a true b
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
|
||||
cu_seqlens:
|
||||
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
|
||||
|
||||
|
||||
## Multipack without Flash Attention
|
||||
|
||||
Multipack can still be achieved without Flash attention, but with lower packing
|
||||
efficiency as we are not able to join multiple batches into a single batch due to
|
||||
context length limits without flash attention. We can use either Pytorch's Scaled
|
||||
Dot Product Attention implementation or native Pytorch attention implementation
|
||||
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
|
||||
to pack sequences together and avoid cross attention.
|
||||
|
||||
<img src="./images/4d-mask.png" alt="axolotl" width="800">
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -98,6 +98,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
@@ -229,11 +233,19 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -243,11 +255,19 @@ class AxolotlTrainer(Trainer):
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
@@ -860,6 +880,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.flash_attention is not True
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
@@ -964,6 +987,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
|
||||
0
src/axolotl/monkeypatch/data/__init__.py
Normal file
0
src/axolotl/monkeypatch/data/__init__.py
Normal file
46
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
Normal file
46
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import torch
|
||||
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
|
||||
from torch.utils.data._utils.worker import _worker_loop
|
||||
|
||||
|
||||
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
def fetch(self, possibly_batched_index):
|
||||
if isinstance(possibly_batched_index[0], list):
|
||||
data = [None for i in possibly_batched_index]
|
||||
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
|
||||
if self.auto_collation:
|
||||
if (
|
||||
hasattr(self.dataset, "__getitems__")
|
||||
and self.dataset.__getitems__
|
||||
):
|
||||
data[i] = self.dataset.__getitems__(possibly_batched_index_)
|
||||
else:
|
||||
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
|
||||
else:
|
||||
data[i] = self.dataset[possibly_batched_index_]
|
||||
else:
|
||||
if self.auto_collation:
|
||||
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
||||
data = self.dataset.__getitems__(possibly_batched_index)
|
||||
else:
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
else:
|
||||
data = self.dataset[possibly_batched_index]
|
||||
return self.collate_fn(data)
|
||||
|
||||
|
||||
def patch_fetchers():
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
|
||||
|
||||
def patched_worker_loop(*args, **kwargs):
|
||||
patch_fetchers()
|
||||
return _worker_loop(*args, **kwargs)
|
||||
|
||||
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
patch_fetchers()
|
||||
@@ -1,142 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# sdp-attn start
|
||||
#
|
||||
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
#
|
||||
# sdp-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
@@ -5,38 +5,11 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
|
||||
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
26
src/axolotl/monkeypatch/llama_patch_multipack.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
import transformers.modeling_attn_mask_utils
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -1,8 +1,15 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -89,7 +96,6 @@ def get_cu_seqlens(attn_mask):
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
@@ -135,7 +141,18 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
# Find the maximum value across all tensors
|
||||
max_value = max(t.max() for t in results)
|
||||
|
||||
# Find the length of the longest tensor
|
||||
max_length = max(t.size(0) for t in results)
|
||||
|
||||
# Pad each tensor to the same length and collect them in a list
|
||||
padded_results = [
|
||||
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
|
||||
]
|
||||
|
||||
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
@@ -149,3 +166,62 @@ def set_module_name(model, name, value):
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,6 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -24,6 +23,11 @@ from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
@@ -124,7 +128,7 @@ def train(
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
@@ -149,7 +153,10 @@ def train(
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
@@ -195,7 +202,7 @@ def train(
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
@@ -132,24 +132,26 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -159,24 +161,26 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features)
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -202,6 +202,20 @@ def validate_config(cfg):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
@@ -350,17 +364,24 @@ def validate_config(cfg):
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
|
||||
@@ -834,7 +834,7 @@ def encode_packed_pretraining(
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
batch_size=1,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_size * max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
@@ -842,15 +842,16 @@ def encode_packed_pretraining(
|
||||
|
||||
chunked_data = defaultdict(list)
|
||||
|
||||
for data in sampler:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
for batch in sampler:
|
||||
for data in batch:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
|
||||
return chunked_data
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from transformers import ( # noqa: F401
|
||||
@@ -324,13 +323,13 @@ def load_model(
|
||||
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
|
||||
hijack_llama_sdp_attention,
|
||||
elif cfg.sample_packing:
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
)
|
||||
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
elif cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
@@ -506,6 +505,12 @@ def load_model(
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif cfg.sdp_attention:
|
||||
model_kwargs["attn_implementation"] = "sdpa"
|
||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||
elif cfg.eager_attention:
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
try:
|
||||
if (
|
||||
@@ -749,6 +754,8 @@ def load_model(
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
if cfg.adapter is not None:
|
||||
|
||||
@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = None
|
||||
self.batch_size = batch_size
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
@@ -147,7 +147,13 @@ class MultipackBatchSampler(BatchSampler):
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
batches = [
|
||||
[
|
||||
[indices[b_idx] for b_idx in batch]
|
||||
for batch in batches[i : i + self.batch_size]
|
||||
]
|
||||
for i in range(0, len(batches), self.batch_size)
|
||||
]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
@@ -189,7 +195,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.batch_max_len
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
)
|
||||
- 1
|
||||
),
|
||||
|
||||
@@ -237,11 +237,17 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention:
|
||||
batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=cfg.micro_batch_size,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
@@ -249,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
data_loader_len = len(data_loader) // batch_size
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
|
||||
114
tests/e2e/patched/test_4d_multipack_llama.py
Normal file
114
tests/e2e/patched/test_4d_multipack_llama.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
E2E tests for multipack fft llama using 4d attention masks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class Test4dMultipackLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_1_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": True,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": False,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
@@ -33,6 +33,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
|
||||
@@ -4,7 +4,9 @@ helper utils for tests
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -31,3 +33,15 @@ def most_recent_subdir(path):
|
||||
subdir = max(subdirectories, key=os.path.getctime)
|
||||
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_1_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.1.1
|
||||
"""
|
||||
|
||||
def is_min_2_1_1():
|
||||
torch_version = version("torch")
|
||||
return torch_version >= "2.1.1"
|
||||
|
||||
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||
|
||||
@@ -30,6 +30,20 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_cu_seqlens_from_pos_ids_2d(self):
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
|
||||
]
|
||||
)
|
||||
target_res = torch.tensor(
|
||||
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_max_seqlen_in_batch(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||
|
||||
99
tests/test_packed_batch_sampler.py
Normal file
99
tests/test_packed_batch_sampler.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
import pytest
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = load_dataset(
|
||||
"Trelis/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"train_on_inputs": True,
|
||||
"sequence_len": max_seq_length,
|
||||
}
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "Text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
completion_strategy,
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
|
||||
tokenizer=tokenizer,
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_seq_length,
|
||||
return_tensors="pt",
|
||||
),
|
||||
num_workers=num_workers,
|
||||
)
|
||||
inputs = next(iter(loader))
|
||||
|
||||
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["labels"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
|
||||
|
||||
assert inputs["input_ids"].tolist()[0][0] == 2
|
||||
assert inputs["labels"].tolist()[0][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[0][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[0][-1] > 1
|
||||
|
||||
if batch_size >= 2:
|
||||
assert inputs["input_ids"].tolist()[1][0] == 2
|
||||
assert inputs["labels"].tolist()[1][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[1][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[1][-1] > 1
|
||||
@@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se
|
||||
from axolotl.utils.data import encode_packed_pretraining
|
||||
|
||||
|
||||
class TestPacking(unittest.TestCase):
|
||||
class TestPretrainingPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user