fix flash-attn, xformers, packing, support chatml

This commit is contained in:
Wing Lian
2023-08-04 10:09:16 -04:00
parent 0b01da0713
commit f93f0017cd
7 changed files with 56 additions and 13 deletions

View File

@@ -38,6 +38,18 @@ def get_cu_seqlens(attn_mask):
)
# Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = attn_mask.shape[1] - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]

View File

@@ -128,7 +128,8 @@ def xformers_forward(
query_states,
key_states,
value_states,
attn_bias=xformers.ops.LowerTriangularMask(),
attn_bias=attention_mask,
# attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:

View File

@@ -10,7 +10,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence. This should result in a block diagonal mask
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
@@ -29,9 +30,14 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Expand the mask to the correct dimensions for the current batch index
expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
inverted_mask = 1.0 - expanded_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min

View File

@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
formatted_sys_prompt = (
self.system_format.format(system=system)
if system and self.system_format
else ""
)
if input:
res = formatted_sys_prompt + self.turn_format.format(
instruction=instruction, input=input
@@ -90,8 +94,15 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>{system}<|im_end|>\n"
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +148,12 @@ def load_open_orca(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca_chatml(tokenizer, cfg):
    """Build an OpenOrca tokenizing strategy whose prompts are rendered in ChatML style.

    Mirrors load_open_orca but selects PromptStyle.CHATML so the prompter emits
    <|im_start|>/<|im_end|> turn markers instead of the plain chat format.
    """
    chatml_prompter = OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value)
    strategy = OpenOrcaPromptTokenizingStrategy(
        chatml_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy

View File

@@ -16,6 +16,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
@@ -39,9 +41,11 @@ class AlpacaPrompter:
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "### System:\n{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM:{system}\n"
def build_prompt(
self,

View File

@@ -259,7 +259,7 @@ class MultipackDistributedDataloader:
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return self._len_est()
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -20,8 +20,8 @@ class TestExpandMask(unittest.TestCase):
[
[
[
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
]
@@ -29,14 +29,14 @@ class TestExpandMask(unittest.TestCase):
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
]
)
print(repr(_expand_mask(mask, dtype)))
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))