set position ids and use block diagonal attn mask
This commit is contained in:
@@ -77,7 +77,12 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
self.tokens_dtype = torch.int64
|
self.tokens_dtype = torch.int64
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
|
buffer = {
|
||||||
|
"input_ids": [],
|
||||||
|
"attention_mask": [],
|
||||||
|
"labels": [],
|
||||||
|
"position_ids": [],
|
||||||
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
for dataset in self.datasets:
|
for dataset in self.datasets:
|
||||||
idx = 0
|
idx = 0
|
||||||
@@ -108,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
|
||||||
: self.seq_length
|
: self.seq_length
|
||||||
]
|
]
|
||||||
|
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
|
||||||
|
: self.seq_length
|
||||||
|
]
|
||||||
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
|
||||||
if labels.size() == input_ids.size() and (
|
if labels.size() == input_ids.size() and (
|
||||||
attention_mask.size() == input_ids.size()
|
attention_mask.size() == input_ids.size()
|
||||||
@@ -116,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": labels,
|
"labels": labels,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"position_ids": position_ids,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
@@ -125,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"input_ids": [],
|
"input_ids": [],
|
||||||
"attention_mask": [],
|
"attention_mask": [],
|
||||||
"labels": [],
|
"labels": [],
|
||||||
|
"position_ids": [],
|
||||||
}
|
}
|
||||||
buffer_len = 0
|
buffer_len = 0
|
||||||
idx = 1
|
idx = 1
|
||||||
@@ -151,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
labels_with_concat = torch.tensor(
|
labels_with_concat = torch.tensor(
|
||||||
labels, dtype=self.tokens_dtype
|
labels, dtype=self.tokens_dtype
|
||||||
)
|
)
|
||||||
|
position_ids = torch.arange(
|
||||||
|
len(input_ids), dtype=self.tokens_dtype
|
||||||
|
)
|
||||||
|
|
||||||
buffer["input_ids"].append(input_ids_with_concat)
|
buffer["input_ids"].append(input_ids_with_concat)
|
||||||
buffer["attention_mask"].append(attention_mask_with_concat)
|
buffer["attention_mask"].append(attention_mask_with_concat)
|
||||||
buffer["labels"].append(labels_with_concat)
|
buffer["labels"].append(labels_with_concat)
|
||||||
|
buffer["position_ids"].append(position_ids)
|
||||||
buffer_len += len(input_ids)
|
buffer_len += len(input_ids)
|
||||||
|
|||||||
35
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
35
src/axolotl/monkeypatch/llama_expand_mask.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
|
binary_mask = torch.where(
|
||||||
|
mask != 0, torch.tensor(1).to(torch.int16), torch.tensor(0).to(torch.int16)
|
||||||
|
)
|
||||||
|
|
||||||
|
zero_one_mask = torch.eq(mask, mask.t()).int() * binary_mask
|
||||||
|
expanded_mask = zero_one_mask.unsqueeze(0).expand(bsz, 1, tgt_len, src_len)
|
||||||
|
|
||||||
|
inverted_mask = 1.0 - expanded_mask
|
||||||
|
|
||||||
|
return inverted_mask.masked_fill(
|
||||||
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_expand_mask():
    """
    Monkeypatch the llama modeling module so it uses this file's `_expand_mask`.
    """
    # Imported lazily so transformers is only loaded when the patch is applied.
    from transformers.models.llama import modeling_llama

    modeling_llama._expand_mask = _expand_mask  # pylint: disable=protected-access
|
||||||
@@ -86,8 +86,10 @@ def load_model(
|
|||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
load_in_8bit = cfg.load_in_8bit
|
load_in_8bit = cfg.load_in_8bit
|
||||||
cfg.is_llama_derived_model = "llama" in base_model or (
|
cfg.is_llama_derived_model = (
|
||||||
cfg.model_type and "llama" in cfg.model_type.lower()
|
"llama" in base_model
|
||||||
|
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||||
|
or cfg.is_llama_derived_model is True
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
@@ -132,6 +134,12 @@ def load_model(
|
|||||||
LOG.info("patching with xpos rope")
|
LOG.info("patching with xpos rope")
|
||||||
replace_llama_rope_with_xpos_rope()
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
|
||||||
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
|
LOG.info("patching _expand_mask")
|
||||||
|
hijack_expand_mask()
|
||||||
|
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
if cfg.bf16 or cfg.bfloat16:
|
||||||
torch_dtype = torch.bfloat16
|
torch_dtype = torch.bfloat16
|
||||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||||
|
|||||||
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
|
|||||||
# first example doesn't have mask reset
|
# first example doesn't have mask reset
|
||||||
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
assert example["input_ids"][0] == self.tokenizer.bos_token_id
|
||||||
assert example["attention_mask"][0] == 1
|
assert example["attention_mask"][0] == 1
|
||||||
|
assert example["position_ids"][0] == 0
|
||||||
|
assert example["position_ids"][1] == 1
|
||||||
|
|
||||||
# but subsequent one does
|
# but subsequent one does
|
||||||
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
|
||||||
assert example["attention_mask"][next_bos_index] == 2
|
assert example["attention_mask"][next_bos_index] == 2
|
||||||
|
assert example["position_ids"][next_bos_index] == 0
|
||||||
|
assert example["position_ids"][next_bos_index + 1] == 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user