fix expand mask for multiple batch items, make sure we pad position_ids
This commit is contained in:
@@ -10,21 +10,41 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
"""
|
"""
|
||||||
|
# Move the mask to the CPU
|
||||||
|
mask = mask.cpu()
|
||||||
|
|
||||||
bsz, src_len = mask.size()
|
bsz, src_len = mask.size()
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
binary_mask = torch.where(
|
# Initialize a tensor to hold the expanded masks
|
||||||
mask != 0, torch.tensor(1).to(torch.int16), torch.tensor(0).to(torch.int16)
|
expanded_masks = torch.zeros(bsz, 1, tgt_len, src_len).to(dtype)
|
||||||
)
|
|
||||||
|
|
||||||
zero_one_mask = torch.eq(mask, mask.t()).int() * binary_mask
|
# For each sequence in the batch
|
||||||
expanded_mask = zero_one_mask.unsqueeze(0).expand(bsz, 1, tgt_len, src_len)
|
for i in range(bsz):
|
||||||
|
# Get the mask for this sequence
|
||||||
|
mask_i = mask[i].unsqueeze(0)
|
||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||||
|
binary_mask_i = torch.where(
|
||||||
|
mask_i != 0,
|
||||||
|
torch.tensor(1).to(dtype).cpu(),
|
||||||
|
torch.tensor(0).to(dtype).cpu(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a block-diagonal mask
|
||||||
|
zero_one_mask_i = torch.eq(mask_i, mask_i.t()).int() * binary_mask_i
|
||||||
|
|
||||||
|
# Expand the mask
|
||||||
|
expanded_mask_i = zero_one_mask_i.unsqueeze(0).expand(1, 1, tgt_len, src_len)
|
||||||
|
|
||||||
|
# Store the expanded mask
|
||||||
|
expanded_masks[i] = expanded_mask_i
|
||||||
|
|
||||||
|
inverted_mask = 1.0 - expanded_masks
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
return inverted_mask.masked_fill(
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
)
|
).cpu()
|
||||||
|
|
||||||
|
|
||||||
def hijack_expand_mask():
|
def hijack_expand_mask():
|
||||||
|
|||||||
121
src/axolotl/utils/collators.py
Normal file
121
src/axolotl/utils/collators.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorForSeq2Seq:
|
||||||
|
"""
|
||||||
|
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
||||||
|
The tokenizer used for encoding the data.
|
||||||
|
model ([`PreTrainedModel`]):
|
||||||
|
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
|
||||||
|
prepare the *decoder_input_ids*
|
||||||
|
|
||||||
|
This is useful when using *label_smoothing* to avoid calculating loss twice.
|
||||||
|
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||||
|
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||||
|
among:
|
||||||
|
|
||||||
|
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
||||||
|
sequence is provided).
|
||||||
|
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||||
|
acceptable input length for the model if that argument is not provided.
|
||||||
|
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
||||||
|
max_length (`int`, *optional*):
|
||||||
|
Maximum length of the returned list and optionally padding length (see above).
|
||||||
|
pad_to_multiple_of (`int`, *optional*):
|
||||||
|
If set will pad the sequence to a multiple of the provided value.
|
||||||
|
|
||||||
|
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||||
|
7.5 (Volta).
|
||||||
|
label_pad_token_id (`int`, *optional*, defaults to -100):
|
||||||
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
|
return_tensors (`str`):
|
||||||
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: PreTrainedTokenizerBase
|
||||||
|
model: Optional[Any] = None
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
|
max_length: Optional[int] = None
|
||||||
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
label_pad_token_id: int = -100
|
||||||
|
position_pad_token_id: int = 0
|
||||||
|
return_tensors: str = "pt"
|
||||||
|
|
||||||
|
def __call__(self, features, return_tensors=None):
|
||||||
|
labels = None
|
||||||
|
if return_tensors is None:
|
||||||
|
return_tensors = self.return_tensors
|
||||||
|
|
||||||
|
for feature_name, pad_token_id in [
|
||||||
|
("labels", self.label_pad_token_id),
|
||||||
|
("position_ids", self.position_pad_token_id),
|
||||||
|
]:
|
||||||
|
feat = (
|
||||||
|
[feature[feature_name] for feature in features]
|
||||||
|
if feature_name in features[0].keys()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
labels = feat if feat and feature_name == "labels" else labels
|
||||||
|
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
||||||
|
# same length to return tensors.
|
||||||
|
if feat is not None:
|
||||||
|
max_feature_length = max(len(l) for l in feat) # noqa: E741
|
||||||
|
if self.pad_to_multiple_of is not None:
|
||||||
|
max_feature_length = (
|
||||||
|
(max_feature_length + self.pad_to_multiple_of - 1)
|
||||||
|
// self.pad_to_multiple_of
|
||||||
|
* self.pad_to_multiple_of
|
||||||
|
)
|
||||||
|
|
||||||
|
padding_side = self.tokenizer.padding_side
|
||||||
|
for feature in features:
|
||||||
|
remainder = [pad_token_id] * (
|
||||||
|
max_feature_length - len(feature[feature_name])
|
||||||
|
)
|
||||||
|
if isinstance(feature[feature_name], list):
|
||||||
|
feature[feature_name] = (
|
||||||
|
feature[feature_name] + remainder
|
||||||
|
if padding_side == "right"
|
||||||
|
else remainder + feature[feature_name]
|
||||||
|
)
|
||||||
|
elif padding_side == "right":
|
||||||
|
feature[feature_name] = np.concatenate(
|
||||||
|
[feature[feature_name], remainder]
|
||||||
|
).astype(np.int64)
|
||||||
|
else:
|
||||||
|
feature[feature_name] = np.concatenate(
|
||||||
|
[remainder, feature[feature_name]]
|
||||||
|
).astype(np.int64)
|
||||||
|
|
||||||
|
features = self.tokenizer.pad(
|
||||||
|
features,
|
||||||
|
padding=self.padding,
|
||||||
|
max_length=self.max_length,
|
||||||
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# prepare decoder_input_ids
|
||||||
|
if (
|
||||||
|
labels is not None
|
||||||
|
and self.model is not None
|
||||||
|
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
||||||
|
):
|
||||||
|
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
|
||||||
|
labels=features["labels"]
|
||||||
|
)
|
||||||
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
return features
|
||||||
@@ -21,6 +21,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
InterpolatingLogScheduler,
|
InterpolatingLogScheduler,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
@@ -346,7 +347,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=eval_dataset,
|
eval_dataset=eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=transformers.DataCollatorForSeq2Seq(
|
data_collator=DataCollatorForSeq2Seq(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user