Attention mask and position id fixes for packing (#285)
* fix attention mask with packing
* set position ids and use block diagonal attn mask
* fix expand mask for multiple batch items, make sure we pad position_ids
* don't move masks to cpu
* use multi pack dataloader w random sampler
* add position_ids back
* more fixes for dataloader integration
* est total tokens, fix field loop
* more fixes, position_ids seems broken
* more fixes for sample packing
* use distributed sampler, avoid accelerate prepare
* use accelerator prepare for dataloader
* fix for position_ids w packing
* Update src/axolotl/utils/dataloader.py
* validation for sample packing and doc
* more fixes for 4k and optimizations
* optimized expand mask fn
* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data
* fix rounding of len of batches to int
* better handling so that all devices have the same dataloader len
* fix step calc for packing
* pass sample packing efficiency to training args
* add a test for the mask expansion for sequence packing
* only process eval dataset for packing if not None
* don't split batches when packing
* weighted CE losses
* weighted CEL fixes
* limit packing to sequences of max seq len
* seq_len_multiple for packing
* make sure the chunk size is an int
* sample_packing_seq_len_multiplier config
* use cumulative seq len with var len flash attn v2 w packing
* properly calculate max len
* fix flash-attn, xformers, packing, support chatml
* fix chatml system prompt for openorca, legacy tokenizer opts
* add chatml
* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test
* fix test and pylint checks
* more packing and dataset optimizations and fixes
* filter w multiple cpus
* more fixes and optimizations
* fixes and go back to distributed sampler since batch sampler won't work
* fix counts by accounting for num devices
* fix steps calculation
* previous accelerate is still most performant
* add numba to requirements.
* use custom distributed checks
* fix sampler to prevent overfit w new epochs
* let's not cleanup the cached datasets
* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier
* speed optimizations and set accelerate fsdp env vars
* optimize dataset concatenation?
* more optimizations for dataset handling
* fix import for annotation
* manual pre-commit fixes
* another sum optimization and bug fix for calc steps
* fix packing estimations
* fix formatting
* pylint problems
* add back flash attention branch for handling unpacked sequences separately
* Address PR feedback
* add optional sample packing config params to readme
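Several items in the commit message above concern sample packing: position ids that restart for each packed sample, a block-diagonal attention mask, and cumulative sequence lengths built from position ids. The following is a minimal pure-Python sketch of those three ideas, not the code from this commit. All function names are hypothetical, and it assumes that position ids restart at 0 at the start of every packed sequence and that the packed row contains no padding.

```python
from typing import List


def packed_position_ids(seq_lens: List[int]) -> List[int]:
    """Position ids for sequences packed back to back, restarting at 0 for each one."""
    return [pos for n in seq_lens for pos in range(n)]


def cu_seq_lens_from_position_ids(position_ids: List[int]) -> List[int]:
    """Cumulative sequence boundaries recovered from position ids.

    Assumption: a position id of 0 marks the first token of a sequence,
    so each such index is a boundary; the total length closes the last one.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    return starts + [len(position_ids)]


def block_diagonal_causal_mask(seq_lens: List[int]) -> List[List[int]]:
    """1 where query token q may attend to key token k, else 0.

    Tokens attend only within their own packed sequence (block diagonal)
    and only to themselves or earlier positions (causal).
    """
    total = sum(seq_lens)
    # seq_id[i] is the index of the packed sequence that token i belongs to.
    seq_id = [idx for idx, n in enumerate(seq_lens) for _ in range(n)]
    return [
        [1 if seq_id[q] == seq_id[k] and k <= q else 0 for k in range(total)]
        for q in range(total)
    ]
```

For two packed samples of lengths 3 and 2, the position ids are `[0, 1, 2, 0, 1]`, the recovered boundaries are `[0, 3, 5]`, and the first token of the second sample sees only itself rather than the three tokens of the first sample.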
@@ -16,6 +16,7 @@ class PromptStyle(Enum):

    INSTRUCT = "instruct"
    CHAT = "chat"
    CHATML = "chatml"


class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:

    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
    system_format: str
    turn_format: str
    turn_no_input_format: str
    prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
        self.match_prompt_style()

    def match_prompt_style(self):
        # pylint: disable=duplicate-code
        if self.prompt_style == PromptStyle.INSTRUCT.value:
            self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            self.turn_no_input_format = (
                "### Instruction:\n{instruction}\n\n### Response:\n"
            )
            self.system_format = "### System:\n{system}\n\n"
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
            self.system_format = "SYSTEM: {system}\n"
        if self.prompt_style == PromptStyle.CHATML.value:
            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
            self.turn_no_input_format = (
                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
            )
            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"

    def build_prompt(
        self,
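To see what the CHATML templates in the hunk above produce, here is a minimal sketch that fills them with `str.format`. The three template strings are copied from the diff. The `render_chatml` helper, its parameter names, and the choice to prepend the system block to the turn are assumptions made for illustration, since the body of `build_prompt` is not shown in this hunk.

```python
from typing import Optional

# Template strings copied from the CHATML branch of match_prompt_style.
SYSTEM_FORMAT = "<|im_start|>system\n{system}<|im_end|>\n"
TURN_FORMAT = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
TURN_NO_INPUT_FORMAT = "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"


def render_chatml(
    instruction: str,
    input_text: Optional[str] = None,
    system: Optional[str] = None,
) -> str:
    """Hypothetical helper: build one ChatML prompt from the templates above."""
    # Assumption: the system block, when given, comes before the user turn.
    prefix = SYSTEM_FORMAT.format(system=system) if system else ""
    if input_text:
        turn = TURN_FORMAT.format(instruction=instruction, input=input_text)
    else:
        turn = TURN_NO_INPUT_FORMAT.format(instruction=instruction)
    return prefix + turn


if __name__ == "__main__":
    print(render_chatml("Add the numbers.", "1 2", system="Be brief."))
```

The prompt ends with an open `<|im_start|>assistant\n` turn, so the model's completion is read as the assistant's reply.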