fix flash-attn, xformers, packing, support chatml

This commit is contained in:
Wing Lian
2023-08-04 10:09:16 -04:00
parent 0b01da0713
commit f93f0017cd
7 changed files with 56 additions and 13 deletions

View File

@@ -38,6 +38,18 @@ def get_cu_seqlens(attn_mask):
) )
# Calculate the sequence lengths # Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1] seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = attn_mask.shape[1] - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths # Calculate the cumulative sequence lengths
cu_seqlens = torch.cat( cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]

View File

@@ -128,7 +128,8 @@ def xformers_forward(
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_bias=xformers.ops.LowerTriangularMask(), attn_bias=attention_mask,
# attn_bias=xformers.ops.LowerTriangularMask(),
) )
attn_weights = None attn_weights = None
else: else:

View File

@@ -10,7 +10,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
""" """
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence. This should result in a block diagonal mask when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
""" """
bsz, src_len = mask.size() bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
@@ -29,9 +30,14 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# we multiply by the binary mask so that 0's in the original mask are correctly excluded # we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Expand the mask to the correct dimensions for the current batch index # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len) lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
inverted_mask = 1.0 - expanded_mask mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill( return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min inverted_mask.to(torch.bool), torch.finfo(dtype).min

View File

@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else "" formatted_sys_prompt = (
self.system_format.format(system=system)
if system and self.system_format
else ""
)
if input: if input:
res = formatted_sys_prompt + self.turn_format.format( res = formatted_sys_prompt + self.turn_format.format(
instruction=instruction, input=input instruction=instruction, input=input
@@ -90,8 +94,15 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n" self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n" self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>{system}<|im_end|>\n"
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy): class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +148,12 @@ def load_open_orca(tokenizer, cfg):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
def load_open_orca_chatml(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -16,6 +16,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct" INSTRUCT = "instruct"
CHAT = "chat" CHAT = "chat"
CHATML = "chatml"
class AlpacaPrompter: class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str
turn_format: str turn_format: str
turn_no_input_format: str turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None prompt_style: Optional[PromptStyle] = None
@@ -39,9 +41,11 @@ class AlpacaPrompter:
self.turn_no_input_format = ( self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n" "### Instruction:\n{instruction}\n\n### Response:\n"
) )
self.system_format = "### System:\n{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value: if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:" self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:" self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM:{system}\n"
def build_prompt( def build_prompt(
self, self,

View File

@@ -259,7 +259,7 @@ class MultipackDistributedDataloader:
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}" f"actual packing efficiency: {self.efficiency()}"
) )
return self._len_est() return max(1, self._len_est())
def efficiency(self): def efficiency(self):
return self.eff_total_used / self.eff_total_slots return self.eff_total_used / self.eff_total_slots

View File

@@ -20,8 +20,8 @@ class TestExpandMask(unittest.TestCase):
[ [
[ [
[ [
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00], [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
] ]
@@ -29,14 +29,14 @@ class TestExpandMask(unittest.TestCase):
[ [
[ [
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38], [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38], [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
] ]
], ],
] ]
) )
print(repr(_expand_mask(mask, dtype)))
# Check that the output matches the expected output # Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))