From ef9bf7ad7378eeee7b25f08db406190cc81a3db3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 17 Jul 2023 06:17:28 -0400 Subject: [PATCH] fix expand mask for multiple batch items, make sure we pad position_ids --- src/axolotl/monkeypatch/llama_expand_mask.py | 34 ++++-- src/axolotl/utils/collators.py | 121 +++++++++++++++++++ src/axolotl/utils/trainer.py | 3 +- 3 files changed, 150 insertions(+), 8 deletions(-) create mode 100644 src/axolotl/utils/collators.py diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index 7e661a1cf..e79e7afc3 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -10,21 +10,41 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ + # Move the mask to the CPU + mask = mask.cpu() + bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - binary_mask = torch.where( - mask != 0, torch.tensor(1).to(torch.int16), torch.tensor(0).to(torch.int16) - ) + # Initialize a tensor to hold the expanded masks + expanded_masks = torch.zeros(bsz, 1, tgt_len, src_len).to(dtype) - zero_one_mask = torch.eq(mask, mask.t()).int() * binary_mask - expanded_mask = zero_one_mask.unsqueeze(0).expand(bsz, 1, tgt_len, src_len) + # For each sequence in the batch + for i in range(bsz): + # Get the mask for this sequence + mask_i = mask[i].unsqueeze(0) - inverted_mask = 1.0 - expanded_mask + # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one + binary_mask_i = torch.where( + mask_i != 0, + torch.tensor(1).to(dtype).cpu(), + torch.tensor(0).to(dtype).cpu(), + ) + + # Create a block-diagonal mask + zero_one_mask_i = torch.eq(mask_i, mask_i.t()).int() * binary_mask_i + + # Expand the mask + expanded_mask_i = zero_one_mask_i.unsqueeze(0).expand(1, 1, tgt_len, src_len) + + # Store the expanded mask + expanded_masks[i] = expanded_mask_i + + inverted_mask = 1.0 - expanded_masks return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min - ) + ).cpu() def hijack_expand_mask(): diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py new file mode 100644 index 000000000..d7acdc977 --- /dev/null +++ b/src/axolotl/utils/collators.py @@ -0,0 +1,121 @@ +""" +DataCollator for axolotl to pad labels and position_ids for packed sequences +""" +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +from transformers import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + + +@dataclass +class DataCollatorForSeq2Seq: + """ + Data collator that will dynamically pad the inputs received, as well as the labels and position_ids + + Args: + tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): + The tokenizer used for encoding the data. + model ([`PreTrainedModel`]): + The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to + prepare the *decoder_input_ids* + + This is useful when using *label_smoothing* to avoid calculating loss twice. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + + - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + label_pad_token_id (`int`, *optional*, defaults to -100): + The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). + return_tensors (`str`): + The type of Tensor to return. Allowable values are "np", "pt" and "tf". + """ + + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + position_pad_token_id: int = 0 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + labels = None + if return_tensors is None: + return_tensors = self.return_tensors + + for feature_name, pad_token_id in [ + ("labels", self.label_pad_token_id), + ("position_ids", self.position_pad_token_id), + ]: + feat = ( + [feature[feature_name] for feature in features] + if feature_name in features[0].keys() + else None + ) + labels = feat if feat and feature_name == "labels" else labels + # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the + # same length to return tensors. + if feat is not None: + max_feature_length = max(len(l) for l in feat) # noqa: E741 + if self.pad_to_multiple_of is not None: + max_feature_length = ( + (max_feature_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + for feature in features: + remainder = [pad_token_id] * ( + max_feature_length - len(feature[feature_name]) + ) + if isinstance(feature[feature_name], list): + feature[feature_name] = ( + feature[feature_name] + remainder + if padding_side == "right" + else remainder + feature[feature_name] + ) + elif padding_side == "right": + feature[feature_name] = np.concatenate( + [feature[feature_name], remainder] + ).astype(np.int64) + else: + feature[feature_name] = np.concatenate( + [remainder, feature[feature_name]] + ).astype(np.int64) + + features = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + # prepare decoder_input_ids + if ( + labels is not None + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + + return features diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 2144e6b02..9c398e21e 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -21,6 +21,7 @@ from axolotl.utils.callbacks import ( SaveBetterTransformerModelCallback, SavePeftModelCallback, ) +from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.schedulers import ( InterpolatingLogScheduler, get_cosine_schedule_with_quadratic_warmup, @@ -346,7 +347,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): train_dataset=train_dataset, eval_dataset=eval_dataset, args=training_args, - data_collator=transformers.DataCollatorForSeq2Seq( + data_collator=DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", **data_collator_kwargs,