diff --git a/README.md b/README.md index 80c0fedff..a7be20dcd 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ Features: - [Inference](#inference) - [Merge LORA to Base](#merge-lora-to-base) - [Special Tokens](#special-tokens) +- Advanced Topics + - [Multipack](./docs/multipack.md) + - [RLHF & DPO](./docs/rlhf.md) - [Common Errors](#common-errors-) - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) - [Debugging Axolotl](#debugging-axolotl) diff --git a/docs/images/4d-mask.png b/docs/images/4d-mask.png new file mode 100644 index 000000000..181e693cf Binary files /dev/null and b/docs/images/4d-mask.png differ diff --git a/docs/multipack.md b/docs/multipack.md index 2a55148b2..bee13b62c 100644 --- a/docs/multipack.md +++ b/docs/multipack.md @@ -1,4 +1,11 @@ -# Multipack +# Multipack (Sample Packing) + +## Visualization of Multipack with Flash Attention + +Because Flash Attention simply drops the attention mask, we do not need to +construct a 4d attention mask. We only need to concatenate the sequences into +a single batch and let flash attention know where each new sequence begins. + 4k context, bsz =4, each character represents 256 tokens @@ -49,3 +56,18 @@ w packing ( note it's the same effective number of tokens per step, but a true b E E E E F F F F F G G G H H H H I I I J J J J K K K K K L L L X ]] ``` + +cu_seqlens: +[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]] + + +## Multipack without Flash Attention + +Multipack can still be achieved without Flash attention, but with lower packing +efficiency as we are not able to join multiple batches into a single batch due to +context length limits without flash attention. We can use either Pytorch's Scaled +Dot Product Attention implementation or native Pytorch attention implementation +along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539) +to pack sequences together and avoid cross attention. + +axolotl diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index c8aea4a71..636a23ba5 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -6,6 +6,7 @@ import logging from dataclasses import dataclass, field from typing import Optional +import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 96054dc50..5e52dafab 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -98,6 +98,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=False, metadata={"help": "Use sample packing for efficient training."}, ) + multipack_real_batches: bool = field( + default=False, + metadata={"help": "Use real batches for efficient training."}, + ) eval_sample_packing: Optional[bool] = field( default=None, metadata={"help": "Use sample packing for efficient evals."}, @@ -229,11 +233,19 @@ class AxolotlTrainer(Trainer): def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: + if self.args.multipack_real_batches: + batch_size = self.args.per_device_train_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + batch_max_len = ( + self.args.per_device_train_batch_size * self.args.max_seq_length + ) return MultipackBatchSampler( RandomSampler(self.train_dataset), - self.args.train_batch_size, + batch_size=batch_size, drop_last=True, - batch_max_len=self._train_batch_size * self.args.max_seq_length, + batch_max_len=batch_max_len, lengths=get_dataset_lengths(self.train_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) @@ -243,11 +255,19 @@ class AxolotlTrainer(Trainer): self, eval_dataset: Dataset ) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and self.args.eval_sample_packing is not False: + if self.args.multipack_real_batches: + batch_size = self.args.per_device_eval_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + batch_max_len = ( + self.args.per_device_eval_batch_size * self.args.max_seq_length + ) return MultipackBatchSampler( SequentialSampler(eval_dataset), - self.args.per_device_eval_batch_size, + batch_size=batch_size, drop_last=True, - batch_max_len=self.args.eval_batch_size * self.args.max_seq_length, + batch_max_len=batch_max_len, lengths=get_dataset_lengths(eval_dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, ) @@ -860,6 +880,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["sample_packing"] = ( self.cfg.sample_packing if self.cfg.sample_packing else False ) + training_arguments_kwargs["multipack_real_batches"] = ( + self.cfg.flash_attention is not True + ) training_arguments_kwargs["eval_sample_packing"] = ( self.cfg.sample_packing if self.cfg.eval_sample_packing is not False @@ -964,6 +987,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if use_batch_sampler_collator: if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]: collator = V2BatchSamplerDataCollatorForSeq2Seq + elif ( + self.cfg.model_config_type in ["llama"] + and self.cfg.flash_attention is not True + ): + collator = V2BatchSamplerDataCollatorForSeq2Seq else: collator = BatchSamplerDataCollatorForSeq2Seq else: diff --git a/src/axolotl/monkeypatch/data/__init__.py b/src/axolotl/monkeypatch/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py new file mode 100644 index 000000000..2e9364e3a --- /dev/null +++ b/src/axolotl/monkeypatch/data/batch_dataset_fetcher.py @@ -0,0 +1,46 @@ +"""monkey patches for the dataset fetcher to handle batches of packed indexes""" +# pylint: disable=protected-access + +import torch +from torch.utils.data._utils.fetch import _BaseDatasetFetcher +from torch.utils.data._utils.worker import _worker_loop + + +class _MapDatasetFetcher(_BaseDatasetFetcher): + def fetch(self, possibly_batched_index): + if isinstance(possibly_batched_index[0], list): + data = [None for i in possibly_batched_index] + for i, possibly_batched_index_ in enumerate(possibly_batched_index): + if self.auto_collation: + if ( + hasattr(self.dataset, "__getitems__") + and self.dataset.__getitems__ + ): + data[i] = self.dataset.__getitems__(possibly_batched_index_) + else: + data[i] = [self.dataset[idx] for idx in possibly_batched_index_] + else: + data[i] = self.dataset[possibly_batched_index_] + else: + if self.auto_collation: + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data = self.dataset.__getitems__(possibly_batched_index) + else: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] + return self.collate_fn(data) + + +def patch_fetchers(): + torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher + torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher + + +def patched_worker_loop(*args, **kwargs): + patch_fetchers() + return _worker_loop(*args, **kwargs) + + +torch.utils.data._utils.worker._worker_loop = patched_worker_loop +patch_fetchers() diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py b/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py deleted file mode 100644 index cfed8cb17..000000000 --- a/src/axolotl/monkeypatch/llama_attn_hijack_sdp.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention -""" - -import warnings -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -import transformers.models.llama.modeling_llama -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv - - -def hijack_llama_sdp_attention(): - transformers.models.llama.modeling_llama.LlamaAttention.forward = ( - sdp_attention_forward - ) - - -def sdp_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument - **kwargs, # pylint: disable=unused-argument -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # pylint: disable=duplicate-code - bsz, q_len, _ = hidden_states.size() - - if not hasattr(self, "pretraining_tp"): - self.pretraining_tp = 1 - - if self.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if output_attentions: - warnings.warn( - "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." - ) - - # - # sdp-attn start - # - - with torch.backends.cuda.sdp_kernel(): - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=False, - ) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - # - # sdp-attn end - # - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1 - ) - attn_output = sum( - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ) - else: - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index d69433baa..5738bb543 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -5,38 +5,11 @@ from typing import Optional import torch +from axolotl.monkeypatch.utils import mask_2d_to_4d + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - This expansion handles packed sequences so that sequences share the same attention mask integer value - when they attend to each other within that sequence. - This expansion transforms the mask to lower triangular form to prevent future peeking. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - mask = mask.unsqueeze(1).unsqueeze(2) - mask = mask.expand(bsz, 1, tgt_len, src_len) - - # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one - binary_mask = torch.where( - mask != 0, - torch.tensor(1).to(dtype), - torch.tensor(0).to(dtype), - ) - - # Create a block-diagonal mask. - # we multiply by the binary mask so that 0's in the original mask are correctly excluded - zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask - - # Now let's create a lower triangular mask of ones that will zero out the upper triangular part - lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( - mask.device - ) - - # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask - masked_zero_one_mask = zero_one_mask * lower_triangular_ones + masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len) inverted_mask = 1.0 - masked_zero_one_mask return inverted_mask.masked_fill( diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py new file mode 100644 index 000000000..540c5577a --- /dev/null +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -0,0 +1,26 @@ +""" +Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention +""" + +from axolotl.monkeypatch.utils import ( + patched_prepare_4d_causal_attention_mask, + patched_prepare_4d_causal_attention_mask_for_sdpa, +) + + +def hijack_llama_prepare_4d_mask(): + import transformers.modeling_attn_mask_utils + import transformers.models.llama.modeling_llama + + transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask_for_sdpa + ) + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask_for_sdpa + ) + transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask + ) + transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + patched_prepare_4d_causal_attention_mask + ) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 193c9d0ec..63141635a 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -1,8 +1,15 @@ """ Shared utils for the monkeypatches """ +from typing import Optional + import torch import torch.nn.functional as F +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.utils import is_torch_bf16_gpu_available @torch.jit.script @@ -89,7 +96,6 @@ def get_cu_seqlens(attn_mask): return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) -@torch.jit.script def get_cu_seqlens_from_pos_ids(position_ids): """generate a cumulative sequence length mask for flash attention using pos ids""" if len(position_ids.shape) == 1: @@ -135,7 +141,18 @@ def get_cu_seqlens_from_pos_ids(position_ids): results.append(cu_seqlens) max_seq_lens.append(max_seq_len) - return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) + # Find the maximum value across all tensors + max_value = max(t.max() for t in results) + + # Find the length of the longest tensor + max_length = max(t.size(0) for t in results) + + # Pad each tensor to the same length and collect them in a list + padded_results = [ + F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results + ] + + return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens) def set_module_name(model, name, value): @@ -149,3 +166,62 @@ def set_module_name(model, name, value): child_name = name setattr(parent, child_name, value) + + +def mask_2d_to_4d( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None +): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + This expansion handles packed sequences so that sequences share the same attention mask integer value + when they attend to each other within that sequence. + This expansion transforms the mask to lower triangular form to prevent future peeking. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + mask = mask.unsqueeze(1).unsqueeze(2) + mask = mask.expand(bsz, 1, tgt_len, src_len) + + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + binary_mask = torch.where( + mask != 0, + torch.tensor(1).to(dtype), + torch.tensor(0).to(dtype), + ) + + # Create a block-diagonal mask. + # we multiply by the binary mask so that 0's in the original mask are correctly excluded + zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask + + # Now let's create a lower triangular mask of ones that will zero out the upper triangular part + lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( + mask.device + ) + + # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask + masked_zero_one_mask = zero_one_mask * lower_triangular_ones + + return masked_zero_one_mask + + +def patched_prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) + + +def patched_prepare_4d_causal_attention_mask_for_sdpa( + attention_mask: Optional[torch.Tensor], + *args, +): + dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 + return _prepare_4d_causal_attention_mask_for_sdpa( + mask_2d_to_4d(attention_mask, dtype=dtype), + *args, + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d0f58bca9..014c281ec 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -11,7 +11,6 @@ import torch import transformers.modelcard from accelerate.logging import get_logger from datasets import Dataset -from optimum.bettertransformer import BetterTransformer from peft import PeftModel from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer @@ -24,6 +23,11 @@ from axolotl.utils.freeze import freeze_parameters_except from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.trainer import setup_trainer +try: + from optimum.bettertransformer import BetterTransformer +except ImportError: + BetterTransformer = None + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) @@ -124,7 +128,7 @@ def train( if cfg.local_rank == 0: def terminate_handler(_, __, model): - if cfg.flash_optimum: + if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) sys.exit(0) @@ -149,7 +153,10 @@ def train( pretrain_hooks(cfg, trainer) if cfg.flash_optimum: with torch.backends.cuda.sdp_kernel( - enable_flash=True, enable_math=True, enable_mem_efficient=True + # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... + enable_flash=True, + enable_math=True, + enable_mem_efficient=True, ): trainer.train(resume_from_checkpoint=resume_from_checkpoint) else: @@ -195,7 +202,7 @@ def train( state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped), ) elif cfg.local_rank == 0: - if cfg.flash_optimum: + if cfg.flash_optimum and BetterTransformer: model = BetterTransformer.reverse(model) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py index da1ee392f..8512b9408 100644 --- a/src/axolotl/utils/collators.py +++ b/src/axolotl/utils/collators.py @@ -132,24 +132,26 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ def __call__(self, features, return_tensors=None): - chunked_data = {} - for feature in features[0].keys(): - if feature == "length": - continue - if feature == "attention_mask": - arrays = [ - (1) * np.array(item[feature]) - for item in features - if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) for item in features if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - features = [chunked_data] - return super().__call__(features, return_tensors=return_tensors) + if not isinstance(features[0], list): + features = [features] + out_features = [{} for _ in features] + for i, features_ in enumerate(features): + for feature in features_[0].keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [ + (1) * np.array(item[feature]) + for i, item in enumerate(features_) + if feature in item + ] + out_features[i][feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item[feature]) for item in features_ if feature in item + ] + out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) @dataclass @@ -159,24 +161,26 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): """ def __call__(self, features, return_tensors=None): - chunked_data = {} - for feature in features[0].keys(): - if feature == "length": - continue - if feature == "attention_mask": - arrays = [ - (i + 1) * np.array(item[feature]) - for i, item in enumerate(features) - if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) for item in features if feature in item - ] - chunked_data[feature] = np.concatenate(arrays) - features = [chunked_data] - return super().__call__(features, return_tensors=return_tensors) + if not isinstance(features[0], list): + features = [features] + out_features = [{} for _ in features] + for i, features_ in enumerate(features): + for feature in features_[0].keys(): + if feature == "length": + continue + if feature == "attention_mask": + arrays = [ + (i + 1) * np.array(item[feature]) + for i, item in enumerate(features_) + if feature in item + ] + out_features[i][feature] = np.concatenate(arrays) + else: + arrays = [ + np.array(item[feature]) for item in features_ if feature in item + ] + out_features[i][feature] = np.concatenate(arrays) + return super().__call__(out_features, return_tensors=return_tensors) @dataclass diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 3ea48f6bb..75b4e5220 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -202,6 +202,20 @@ def validate_config(cfg): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." ) + if ( + # pylint: disable=too-many-boolean-expressions + not (cfg.bf16 or cfg.bfloat16) + and (cfg.fp16 or cfg.float16) + and not cfg.adapter + and not cfg.flash_attention + and cfg.sample_packing + ): + LOG.warning( + "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." + ) + # ValueError: Attempting to unscale FP16 gradients. + # OR + # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half if cfg.max_packed_sequence_len: raise DeprecationWarning("`max_packed_sequence_len` is no longer supported") @@ -350,17 +364,24 @@ def validate_config(cfg): + "point to its path, and remove model_revision from the config." ) - if cfg.sample_packing and cfg.sdp_attention: - # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 - raise ValueError( - "sample_packing not compatible with sdp_attention. Use flash_attention" - ) + # if cfg.sample_packing and cfg.sdp_attention: + # # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2 + # raise ValueError( + # "sample_packing not compatible with sdp_attention. Use flash_attention" + # ) if cfg.sample_packing and cfg.xformers_attention: raise ValueError( "sample_packing not compatible with xformers_attention. Use flash_attention" ) + if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16): + # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 + LOG.warning( + "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " + "This may work on H100s." + ) + if cfg.early_stopping_patience: if not cfg.save_steps or not cfg.eval_steps: raise ValueError( diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 39766868e..5e6ceb6cb 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -834,7 +834,7 @@ def encode_packed_pretraining( sampler = MultipackBatchSampler( RandomSampler(train_dataset), - batch_size=batch_size, + batch_size=1, drop_last=True, batch_max_len=batch_size * max_seq_length, lengths=get_dataset_lengths(train_dataset), @@ -842,15 +842,16 @@ def encode_packed_pretraining( chunked_data = defaultdict(list) - for data in sampler: - features = train_dataset[data] - features["labels"] = features["input_ids"].copy() - collated_features = collate_fn(features) + for batch in sampler: + for data in batch: + features = train_dataset[data] + features["labels"] = features["input_ids"].copy() + collated_features = collate_fn(features) - for feature in features.keys(): - if feature == "length": - continue - chunked_data[feature].append(collated_features[feature].squeeze(0)) + for feature in features.keys(): + if feature == "length": + continue + chunked_data[feature].append(collated_features[feature].squeeze(0)) return chunked_data diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ad3ac7bc6..224e3a258 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -8,7 +8,6 @@ import addict import bitsandbytes as bnb import torch import transformers -from optimum.bettertransformer import BetterTransformer from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training from peft.tuners.lora import QuantLinear from transformers import ( # noqa: F401 @@ -324,13 +323,13 @@ def load_model( LOG.info("patching with xformers attention") hijack_llama_attention() - elif cfg.sdp_attention: - from axolotl.monkeypatch.llama_attn_hijack_sdp import ( - hijack_llama_sdp_attention, + elif cfg.sample_packing: + from axolotl.monkeypatch.llama_patch_multipack import ( + hijack_llama_prepare_4d_mask, ) - LOG.info("patching with sdp attention") - hijack_llama_sdp_attention() + LOG.info("patching llama _prepare_4d_causal_attention_mask*") + hijack_llama_prepare_4d_mask() elif cfg.s2_attention: raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." @@ -506,6 +505,12 @@ def load_model( model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) + elif cfg.sdp_attention: + model_kwargs["attn_implementation"] = "sdpa" + model_config._attn_implementation = "sdpa" # pylint: disable=protected-access + elif cfg.eager_attention: + model_kwargs["attn_implementation"] = "eager" + model_config._attn_implementation = "eager" # pylint: disable=protected-access try: if ( @@ -749,6 +754,8 @@ def load_model( model.config.use_cache = False if cfg.flash_optimum: + from optimum.bettertransformer import BetterTransformer + model = BetterTransformer.transform(model) if cfg.adapter is not None: diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 629a1a44c..cf47d9639 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler): packing_efficiency_estimate: float = 1.0, ): super().__init__(sampler, batch_size, drop_last) - self.batch_size = None + self.batch_size = batch_size self.batch_max_len = batch_max_len self.lengths: np.ndarray = lengths self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 @@ -147,7 +147,13 @@ class MultipackBatchSampler(BatchSampler): n=1, ) - batches = [[indices[b_idx] for b_idx in batch] for batch in batches] + batches = [ + [ + [indices[b_idx] for b_idx in batch] + for batch in batches[i : i + self.batch_size] + ] + for i in range(0, len(batches), self.batch_size) + ] # statistics if set_stats: @@ -189,7 +195,7 @@ class MultipackBatchSampler(BatchSampler): 0.99 * lengths_sum_per_device / self.packing_efficiency_estimate - // self.batch_max_len + // (self.batch_max_len * self.batch_size) ) - 1 ), diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 26197b991..6d4168fbd 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -237,11 +237,17 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): main_process_only=True, ) else: + if cfg.flash_attention: + batch_size = 1 + batch_max_len = cfg.micro_batch_size * cfg.sequence_len + else: + batch_size = cfg.micro_batch_size + batch_max_len = cfg.sequence_len sampler = MultipackBatchSampler( sampler=RandomSampler(train_dataset), - batch_size=cfg.micro_batch_size, + batch_size=batch_size, drop_last=True, - batch_max_len=cfg.micro_batch_size * cfg.sequence_len, + batch_max_len=batch_max_len, lengths=get_dataset_lengths(train_dataset), ) @@ -249,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): train_dataset.remove_columns(["length"]), batch_sampler=sampler, ) - data_loader_len = len(data_loader) + data_loader_len = len(data_loader) // batch_size actual_eff = sampler.efficiency() LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) # FIXME: is there a bug here somewhere? the total num steps depends diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py new file mode 100644 index 000000000..d74d09723 --- /dev/null +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -0,0 +1,114 @@ +""" +E2E tests for multipack fft llama using 4d attention masks +""" + +import logging +import os +import unittest +from pathlib import Path + +from axolotl.cli import load_datasets +from axolotl.common.cli import TrainerCliArgs +from axolotl.train import train +from axolotl.utils.config import normalize_config +from axolotl.utils.dict import DictDefault + +from ..utils import require_torch_2_1_1, with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class Test4dMultipackLlama(unittest.TestCase): + """ + Test case for Llama models using 4d attention with multipack + """ + + @require_torch_2_1_1 + @with_temp_dir + def test_sdp_lora_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "flash_attention": False, + "sdp_attention": True, + "sample_packing": True, + "pad_to_sequence_len": True, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "sequence_len": 1024, + "val_set_size": 0.1, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "fp16": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() + + @with_temp_dir + def test_torch_lora_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "flash_attention": False, + "sdp_attention": False, + "sample_packing": True, + "pad_to_sequence_len": True, + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 2, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "eval_steps": 10, + "fp16": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.bin").exists() diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 96ff5eee8..dda08a463 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -33,6 +33,7 @@ class TestFusedLlama(unittest.TestCase): { "base_model": "JackFram/llama-68m", "flash_attention": True, + "pad_to_sequence_len": True, "flash_attn_fuse_qkv": True, "flash_attn_fuse_mlp": True, "sample_packing": True, diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 203824fc9..837b4734f 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -4,7 +4,9 @@ helper utils for tests import os import shutil import tempfile +import unittest from functools import wraps +from importlib.metadata import version from pathlib import Path @@ -31,3 +33,15 @@ def most_recent_subdir(path): subdir = max(subdirectories, key=os.path.getctime) return subdir + + +def require_torch_2_1_1(test_case): + """ + Decorator marking a test that requires torch >= 2.1.1 + """ + + def is_min_2_1_1(): + torch_version = version("torch") + return torch_version >= "2.1.1" + + return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case) diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py index cce421e88..4521cd07b 100644 --- a/tests/monkeypatch/test_llama_attn_hijack_flash.py +++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py @@ -30,6 +30,20 @@ class TestMonkeyPatchUtils(unittest.TestCase): torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res) ) + def test_get_cu_seqlens_from_pos_ids_2d(self): + position_ids = torch.tensor( + [ + [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0], + [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0], + ] + ) + target_res = torch.tensor( + [[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32 + ) + self.assertTrue( + torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res) + ) + def test_get_max_seqlen_in_batch(self): attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]) target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32) diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py new file mode 100644 index 000000000..50f39d60f --- /dev/null +++ b/tests/test_packed_batch_sampler.py @@ -0,0 +1,99 @@ +"""Module for testing streaming dataset sequence packing""" +import pytest +from datasets import concatenate_datasets, load_dataset +from torch.utils.data import DataLoader, RandomSampler +from transformers import AutoTokenizer + +from axolotl.datasets import TokenizedPromptDataset +from axolotl.prompt_strategies.completion import load +from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.dict import DictDefault +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths + + +@pytest.fixture(name="tokenizer") +def fixture_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") + tokenizer.pad_token = "" + return tokenizer + + +@pytest.fixture(name="max_seq_length") +def fixture_max_seq_length(): + return 4096 + + +class TestBatchedSamplerPacking: + """ + Test class for packing streaming dataset sequences + """ + + @pytest.mark.parametrize( + "batch_size, num_workers", + [ + (1, 0), + (2, 0), + (1, 2), + (2, 2), + ], + ) + def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): + import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 + + dataset = load_dataset( + "Trelis/tiny-shakespeare", + split="train", + ) + + cfg = DictDefault( + { + "train_on_inputs": True, + "sequence_len": max_seq_length, + } + ) + ds_cfg = DictDefault( + { + "field": "Text", + } + ) + completion_strategy = load(tokenizer, cfg, ds_cfg) + dataset_wrapper = TokenizedPromptDataset( + completion_strategy, + dataset, + ) + train_dataset = concatenate_datasets([dataset_wrapper]) + batch_sampler = MultipackBatchSampler( + sampler=RandomSampler(train_dataset), + batch_size=batch_size, + drop_last=True, + batch_max_len=max_seq_length, + lengths=get_dataset_lengths(train_dataset), + ) + + loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg + tokenizer=tokenizer, + padding=True, + pad_to_multiple_of=max_seq_length, + return_tensors="pt", + ), + num_workers=num_workers, + ) + inputs = next(iter(loader)) + + assert inputs["input_ids"].shape == (batch_size, max_seq_length) + assert inputs["labels"].shape == (batch_size, max_seq_length) + assert inputs["attention_mask"].shape == (batch_size, max_seq_length) + + assert inputs["input_ids"].tolist()[0][0] == 2 + assert inputs["labels"].tolist()[0][0] == -100 + assert inputs["attention_mask"].tolist()[0][0] == 0 + assert inputs["attention_mask"].tolist()[0][-1] > 1 + + if batch_size >= 2: + assert inputs["input_ids"].tolist()[1][0] == 2 + assert inputs["labels"].tolist()[1][0] == -100 + assert inputs["attention_mask"].tolist()[1][0] == 0 + assert inputs["attention_mask"].tolist()[1][-1] > 1 diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index a47e3983f..be4b7075b 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se from axolotl.utils.data import encode_packed_pretraining -class TestPacking(unittest.TestCase): +class TestPretrainingPacking(unittest.TestCase): """ Test class for packing streaming dataset sequences """