From 4964b0d345f284b1305d926314524b98c504f470 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 17 Jul 2023 01:56:32 -0400
Subject: [PATCH] set position ids and use block diagonal attn mask

---
 src/axolotl/datasets.py                      | 16 ++++++++-
 src/axolotl/monkeypatch/llama_expand_mask.py | 35 ++++++++++++++++++++
 src/axolotl/utils/models.py                  | 12 +++++--
 tests/test_packed_dataset.py                 |  4 +++
 4 files changed, 64 insertions(+), 3 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/llama_expand_mask.py

diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 10656c8c7..4376fb18a 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -77,7 +77,12 @@ class ConstantLengthDataset(IterableDataset):
             self.tokens_dtype = torch.int64
 
     def __iter__(self):
-        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            "position_ids": [],
+        }
         buffer_len = 0
         for dataset in self.datasets:
             idx = 0
@@ -108,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
                         attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                             : self.seq_length
                         ]
+                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                            : self.seq_length
+                        ]
                         labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                         if labels.size() == input_ids.size() and (
                             attention_mask.size() == input_ids.size()
@@ -116,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
                                 "input_ids": input_ids,
                                 "labels": labels,
                                 "attention_mask": attention_mask,
+                                "position_ids": position_ids,
                             }
                         else:
                             LOG.warning(
@@ -125,6 +134,7 @@ class ConstantLengthDataset(IterableDataset):
                         "input_ids": [],
                         "attention_mask": [],
                         "labels": [],
+                        "position_ids": [],
                     }
                     buffer_len = 0
                     idx = 1
@@ -151,8 +161,12 @@ class ConstantLengthDataset(IterableDataset):
                     labels_with_concat = torch.tensor(
                         labels, dtype=self.tokens_dtype
                     )
+                    position_ids = torch.arange(
+                        len(input_ids), dtype=self.tokens_dtype
+                    )
                     buffer["input_ids"].append(input_ids_with_concat)
                     buffer["attention_mask"].append(attention_mask_with_concat)
                     buffer["labels"].append(labels_with_concat)
+                    buffer["position_ids"].append(position_ids)
                     buffer_len += len(input_ids)
 
diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py
new file mode 100644
index 000000000..7e661a1cf
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_expand_mask.py
@@ -0,0 +1,35 @@
+"""
+expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
+"""
+from typing import Optional
+
+import torch
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    binary_mask = torch.where(
+        mask != 0, torch.tensor(1).to(torch.int16), torch.tensor(0).to(torch.int16)
+    )
+
+    zero_one_mask = torch.eq(mask, mask.t()).int() * binary_mask
+    expanded_mask = zero_one_mask.unsqueeze(0).expand(bsz, 1, tgt_len, src_len)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
+def hijack_expand_mask():
+    import transformers
+
+    transformers.models.llama.modeling_llama._expand_mask = (  # pylint: disable=protected-access
+        _expand_mask
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7501878ba..6a39eede1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -86,8 +86,10 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
-    cfg.is_llama_derived_model = "llama" in base_model or (
-        cfg.model_type and "llama" in cfg.model_type.lower()
+    cfg.is_llama_derived_model = (
+        "llama" in base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+        or cfg.is_llama_derived_model is True
     )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -132,6 +134,12 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()
 
+    if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
+        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
+
+        LOG.info("patching _expand_mask")
+        hijack_expand_mask()
+
     if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py
index a0f476a42..da8fb7a93 100644
--- a/tests/test_packed_dataset.py
+++ b/tests/test_packed_dataset.py
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
         # first example doesn't have mask reset
         assert example["input_ids"][0] == self.tokenizer.bos_token_id
         assert example["attention_mask"][0] == 1
+        assert example["position_ids"][0] == 0
+        assert example["position_ids"][1] == 1
 
         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
         assert example["attention_mask"][next_bos_index] == 2
+        assert example["position_ids"][next_bos_index] == 0
+        assert example["position_ids"][next_bos_index + 1] == 1
 
 
 if __name__ == "__main__":
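
Not part of the patch itself: a minimal sketch of what the patched _expand_mask returns for one packed example, assuming an axolotl install that includes this change. The attention-mask values 1, 2, ... mirror the per-sequence indices that ConstantLengthDataset assigns while packing (see the test assertions above); the toy tensor below is illustrative only.

import torch

from axolotl.monkeypatch.llama_expand_mask import _expand_mask

# Two packed sequences (lengths 3 and 2) plus one padding position (0);
# packing numbers each sequence 1, 2, ... in attention_mask.
packed_attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])

expanded = _expand_mask(packed_attention_mask, dtype=torch.float32)

# Shape is [bsz, 1, seq_len, seq_len]; entries are 0.0 where a query may attend
# (within its own packed sequence) and torch.finfo(dtype).min everywhere else,
# i.e. the block-diagonal additive mask from section 3.2.2 of the cited paper.
print(expanded.shape)               # torch.Size([1, 1, 6, 6])
print((expanded == 0).int()[0, 0])  # block of 3, block of 2, padding fully masked

This is the behavior load_model opts into when max_packed_sequence_len is set on a llama-derived model: hijack_expand_mask swaps the stock transformers _expand_mask for this one, so packed sequences stop attending across their boundaries.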