From f93f0017cd85119401443fce63012391633f1bff Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 4 Aug 2023 10:09:16 -0400 Subject: [PATCH] fix flash-attn, xformers, packing, support chatml --- .../monkeypatch/llama_attn_hijack_flash.py | 12 +++++++++ .../monkeypatch/llama_attn_hijack_xformers.py | 3 ++- src/axolotl/monkeypatch/llama_expand_mask.py | 14 +++++++--- .../prompt_strategies/alpaca_w_system.py | 26 ++++++++++++++++--- src/axolotl/prompters.py | 4 +++ src/axolotl/utils/dataloader.py | 2 +- tests/test_expand_mask.py | 8 +++--- 7 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 74f5d0f9e..ada4ce73e 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -38,6 +38,18 @@ def get_cu_seqlens(attn_mask): ) # Calculate the sequence lengths seq_lengths = change_indices[1:] - change_indices[:-1] + # Calculate the length of the final sequence or padding + final_seq_length = attn_mask.shape[1] - change_indices[-1] + # Append the length of the final sequence or padding to seq_lengths + if final_seq_length.item(): + seq_lengths = torch.cat( + [ + seq_lengths, + torch.tensor( + [final_seq_length.item()], dtype=torch.int32, device=device + ), + ] + ) # Calculate the cumulative sequence lengths cu_seqlens = torch.cat( [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 02525b7f5..4755db30b 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -128,7 +128,8 @@ def xformers_forward( query_states, key_states, value_states, - attn_bias=xformers.ops.LowerTriangularMask(), + attn_bias=attention_mask, + # attn_bias=xformers.ops.LowerTriangularMask(), ) attn_weights = None else: diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py index 3bea39531..d69433baa 100644 --- a/src/axolotl/monkeypatch/llama_expand_mask.py +++ b/src/axolotl/monkeypatch/llama_expand_mask.py @@ -10,7 +10,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. This expansion handles packed sequences so that sequences share the same attention mask integer value - when they attend to each other within that sequence. This should result in a block diagonal mask + when they attend to each other within that sequence. + This expansion transforms the mask to lower triangular form to prevent future peeking. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len @@ -29,9 +30,14 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] # we multiply by the binary mask so that 0's in the original mask are correctly excluded zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask - # Expand the mask to the correct dimensions for the current batch index - expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len) - inverted_mask = 1.0 - expanded_mask + # Now let's create a lower triangular mask of ones that will zero out the upper triangular part + lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to( + mask.device + ) + + # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask + masked_zero_one_mask = zero_one_mask * lower_triangular_ones + inverted_mask = 1.0 - masked_zero_one_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index ea7151366..8875ec7bc 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter): ) -> Generator[str, None, None]: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. - formatted_sys_prompt = f"### System:\n{system}\n\n" if system else "" + formatted_sys_prompt = ( + self.system_format.format(system=system) + if system and self.system_format + else "" + ) if input: res = formatted_sys_prompt + self.turn_format.format( instruction=instruction, input=input @@ -90,8 +94,15 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter): self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" if self.prompt_style == PromptStyle.CHAT.value: - self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" - self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" + self.turn_format = "User: {instruction}\n{input}\nAssistant:" + self.turn_no_input_format = "User: {instruction}\nAssistant:" + self.system_format = "System: {system}\n" + if self.prompt_style == PromptStyle.CHATML.value: + self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n" + self.turn_no_input_format = ( + "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n" + ) + self.system_format = "<|im_start|>{system}<|im_end|>\n" class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): @@ -137,3 +148,12 @@ def load_open_orca(tokenizer, cfg): cfg.train_on_inputs, cfg.sequence_len, ) + + +def load_open_orca_chatml(tokenizer, cfg): + return OpenOrcaPromptTokenizingStrategy( + OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index a304bd137..4ce7fb65d 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -16,6 +16,7 @@ class PromptStyle(Enum): INSTRUCT = "instruct" CHAT = "chat" + CHATML = "chatml" class AlpacaPrompter: @@ -25,6 +26,7 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + system_format: str turn_format: str turn_no_input_format: str prompt_style: Optional[PromptStyle] = None @@ -39,9 +41,11 @@ class AlpacaPrompter: self.turn_no_input_format = ( "### Instruction:\n{instruction}\n\n### Response:\n" ) + self.system_format = "### System:\n{system}\n\n" if self.prompt_style == PromptStyle.CHAT.value: self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" + self.system_format = "SYSTEM:{system}\n" def build_prompt( self, diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py index d8ba7f567..fbd22eb57 100644 --- a/src/axolotl/utils/dataloader.py +++ b/src/axolotl/utils/dataloader.py @@ -259,7 +259,7 @@ class MultipackDistributedDataloader: f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"actual packing efficiency: {self.efficiency()}" ) - return self._len_est() + return max(1, self._len_est()) def efficiency(self): return self.eff_total_used / self.eff_total_slots diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py index 2d943b6c4..885d7f0d7 100644 --- a/tests/test_expand_mask.py +++ b/tests/test_expand_mask.py @@ -20,8 +20,8 @@ class TestExpandMask(unittest.TestCase): [ [ [ - [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], - [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], + [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38], [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00], ] @@ -29,14 +29,14 @@ class TestExpandMask(unittest.TestCase): [ [ [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], - [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], + [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38], [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38], ] ], ] ) - + print(repr(_expand_mask(mask, dtype))) # Check that the output matches the expected output self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))