fix flash-attn, xformers, packing, support chatml
This commit is contained in:
@@ -38,6 +38,18 @@ def get_cu_seqlens(attn_mask):
|
|||||||
)
|
)
|
||||||
# Calculate the sequence lengths
|
# Calculate the sequence lengths
|
||||||
seq_lengths = change_indices[1:] - change_indices[:-1]
|
seq_lengths = change_indices[1:] - change_indices[:-1]
|
||||||
|
# Calculate the length of the final sequence or padding
|
||||||
|
final_seq_length = attn_mask.shape[1] - change_indices[-1]
|
||||||
|
# Append the length of the final sequence or padding to seq_lengths
|
||||||
|
if final_seq_length.item():
|
||||||
|
seq_lengths = torch.cat(
|
||||||
|
[
|
||||||
|
seq_lengths,
|
||||||
|
torch.tensor(
|
||||||
|
[final_seq_length.item()], dtype=torch.int32, device=device
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
# Calculate the cumulative sequence lengths
|
# Calculate the cumulative sequence lengths
|
||||||
cu_seqlens = torch.cat(
|
cu_seqlens = torch.cat(
|
||||||
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
|
||||||
|
|||||||
@@ -128,7 +128,8 @@ def xformers_forward(
|
|||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
attn_bias=attention_mask,
|
||||||
|
# attn_bias=xformers.ops.LowerTriangularMask(),
|
||||||
)
|
)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||||
when they attend to each other within that sequence. This should result in a block diagonal mask
|
when they attend to each other within that sequence.
|
||||||
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
bsz, src_len = mask.size()
|
bsz, src_len = mask.size()
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
@@ -29,9 +30,14 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|||||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||||
|
|
||||||
# Expand the mask to the correct dimensions for the current batch index
|
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||||
expanded_mask = zero_one_mask.expand(bsz, 1, tgt_len, src_len)
|
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||||
inverted_mask = 1.0 - expanded_mask
|
mask.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||||
|
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||||
|
inverted_mask = 1.0 - masked_zero_one_mask
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
return inverted_mask.masked_fill(
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||||
|
|||||||
@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
|
|||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
|
formatted_sys_prompt = (
|
||||||
|
self.system_format.format(system=system)
|
||||||
|
if system and self.system_format
|
||||||
|
else ""
|
||||||
|
)
|
||||||
if input:
|
if input:
|
||||||
res = formatted_sys_prompt + self.turn_format.format(
|
res = formatted_sys_prompt + self.turn_format.format(
|
||||||
instruction=instruction, input=input
|
instruction=instruction, input=input
|
||||||
@@ -90,8 +94,15 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
|
|||||||
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
|
||||||
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
|
||||||
if self.prompt_style == PromptStyle.CHAT.value:
|
if self.prompt_style == PromptStyle.CHAT.value:
|
||||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
|
||||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
self.turn_no_input_format = "User: {instruction}\nAssistant:"
|
||||||
|
self.system_format = "System: {system}\n"
|
||||||
|
if self.prompt_style == PromptStyle.CHATML.value:
|
||||||
|
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
self.turn_no_input_format = (
|
||||||
|
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
)
|
||||||
|
self.system_format = "<|im_start|>{system}<|im_end|>\n"
|
||||||
|
|
||||||
|
|
||||||
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||||
@@ -137,3 +148,12 @@ def load_open_orca(tokenizer, cfg):
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_open_orca_chatml(tokenizer, cfg):
|
||||||
|
return OpenOrcaPromptTokenizingStrategy(
|
||||||
|
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class PromptStyle(Enum):
|
|||||||
|
|
||||||
INSTRUCT = "instruct"
|
INSTRUCT = "instruct"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
|
CHATML = "chatml"
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPrompter:
|
class AlpacaPrompter:
|
||||||
@@ -25,6 +26,7 @@ class AlpacaPrompter:
|
|||||||
|
|
||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||||
|
system_format: str
|
||||||
turn_format: str
|
turn_format: str
|
||||||
turn_no_input_format: str
|
turn_no_input_format: str
|
||||||
prompt_style: Optional[PromptStyle] = None
|
prompt_style: Optional[PromptStyle] = None
|
||||||
@@ -39,9 +41,11 @@ class AlpacaPrompter:
|
|||||||
self.turn_no_input_format = (
|
self.turn_no_input_format = (
|
||||||
"### Instruction:\n{instruction}\n\n### Response:\n"
|
"### Instruction:\n{instruction}\n\n### Response:\n"
|
||||||
)
|
)
|
||||||
|
self.system_format = "### System:\n{system}\n\n"
|
||||||
if self.prompt_style == PromptStyle.CHAT.value:
|
if self.prompt_style == PromptStyle.CHAT.value:
|
||||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||||
|
self.system_format = "SYSTEM:{system}\n"
|
||||||
|
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ class MultipackDistributedDataloader:
|
|||||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
f"actual packing efficiency: {self.efficiency()}"
|
f"actual packing efficiency: {self.efficiency()}"
|
||||||
)
|
)
|
||||||
return self._len_est()
|
return max(1, self._len_est())
|
||||||
|
|
||||||
def efficiency(self):
|
def efficiency(self):
|
||||||
return self.eff_total_used / self.eff_total_slots
|
return self.eff_total_used / self.eff_total_slots
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ class TestExpandMask(unittest.TestCase):
|
|||||||
[
|
[
|
||||||
[
|
[
|
||||||
[
|
[
|
||||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
||||||
]
|
]
|
||||||
@@ -29,14 +29,14 @@ class TestExpandMask(unittest.TestCase):
|
|||||||
[
|
[
|
||||||
[
|
[
|
||||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
print(repr(_expand_mask(mask, dtype)))
|
||||||
# Check that the output matches the expected output
|
# Check that the output matches the expected output
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user