* fix attention mask with packing * set position ids and use block diagonal attn mask * fix expand mask for multiple batch items, make sure we pad position_ids * don't move masks to cpu * use multi pack dataloader w random sampler * add position_ids back * more fixes for dataloader integration * est total tokens, fix field loop * more fixes, position_ids seems broken * more fixes for sample packing * use distributed sampler, avoid accelerate prepare * use accelerator prepare for dataloader * fix for position_ids w packing * Update src/axolotl/utils/dataloader.py * validation for sample packing and doc * more fixes for 4k and optimizations * optimized expand mask fn * better handling of variance in multipack dataloader length and trainer hanging when it runs out of data * fix rounding of len of batches to int * better handling so that all devices have the same dataloader len * fix step calc for packing * pass sample packing efficiency to training args * add a test for the mask expansion for sequence packing * only process eval dataset for packing if not None * don't split batches when packing * weighted CE losses * weighted CEL fixes * limit packing to sequences of max seq len * seq_len_multiple for packing * make sure the chunk size is an int * sample_packing_seq_len_multiplier config * use cumulative seq len with var len flash attn v2 w packing * properly calculate max len * fix flash-attn, xformers, packing, support chatml * fix chatml system prompt for openorca, legacy tokenizer opts * add chatml * add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test * fix test and pylint checks * more packing and dataset optimizations and fixes * filter w multiple cpus * more fixes and optimizations * fixes and go back to distributed sampler since batch sampler won't work * fix counts by accounting for num devices * fix steps calculation * previous accelerate is still most performant * add numba to requirements. 
* use custom distributed checks * fix sampler to prevent overfit w new epochs * let's not cleanup the cached datasets * calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier * speed optimizations and set accelerate fsdp env vars * optimize dataset concatenation? * more optimizations for dataset handling * fix import for annotation * manual pre-commit fixes * another sum optimization and bug fix for calc steps * fix packing estimations * fix formatting * pylint problems * add back flash attention branch for handling unpacked sequences separately * Address PR feedback * add optional sample packing config params to readme
125 lines
4.3 KiB
Python
125 lines
4.3 KiB
Python
"""Module testing prompters"""
|
|
|
|
import unittest
|
|
|
|
from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
|
|
from axolotl.prompters import (
|
|
AlpacaPrompter,
|
|
MultipleChoiceExplainPrompter,
|
|
PromptStyle,
|
|
UnpromptedPrompter,
|
|
)
|
|
|
|
|
|
class AlpacaPrompterTest(unittest.TestCase):
    """
    Test AlpacaPrompter
    """

    @staticmethod
    def _check_tokens(text, present=(), absent=()):
        # Shared helper: verify which template markers appear in the rendered prompt.
        for token in present:
            assert token in text
        for token in absent:
            assert token not in text

    def test_prompt_style_w_none(self):
        prompter = AlpacaPrompter(prompt_style=None)
        output = next(prompter.build_prompt("tell me a joke"))
        # just testing that it uses instruct style
        assert "### Instruction:" in output

    def test_prompt_style_w_instruct(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.INSTRUCT.value)

        # With an input section: alpaca instruct markers, no chat markers.
        with_input = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self._check_tokens(
            with_input,
            present=(
                "Below is an instruction",
                "### Instruction:",
                "### Input:",
                "alpacas",
                "### Response:",
            ),
            absent=("USER:", "ASSISTANT:"),
        )

        # Without an input section: the "### Input:" marker must be dropped.
        without_input = next(
            prompter.build_prompt("tell me a joke about the following")
        )
        self._check_tokens(
            without_input,
            present=("Below is an instruction", "### Instruction:", "### Response:"),
            absent=("### Input:", "USER:", "ASSISTANT:"),
        )

    def test_prompt_style_w_chat(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)

        # With an input section: chat markers only, no alpaca section headers.
        with_input = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        self._check_tokens(
            with_input,
            present=("Below is an instruction", "alpacas", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )

        # Without an input section: same chat markers, still no alpaca headers.
        without_input = next(
            prompter.build_prompt("tell me a joke about the following")
        )
        self._check_tokens(
            without_input,
            present=("Below is an instruction", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )

    def test_system_prompt(self):
        prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
        output = next(
            prompter.build_prompt_w_system(
                "use cot", "tell me a joke about the following", "alpacas"
            )
        )
        # The system text must lead the prompt and the chat markers must follow.
        assert output.startswith("SYSTEM:")
        self._check_tokens(
            output,
            present=("use cot", "alpacas", "USER:", "ASSISTANT:"),
            absent=("### Instruction:", "### Input:", "### Response:"),
        )
|
|
|
|
|
|
class UnpromptedPrompterTest(unittest.TestCase):
    """
    Test class for UnpromptedPrompter with no system prompts
    """

    def test_prompt_style_w_none(self):
        prompter = UnpromptedPrompter(prompt_style=None)
        rendered = next(prompter.build_prompt("tell me a joke"))
        # No system preamble: the prompt opens directly with the section header.
        assert rendered.startswith("###")
        assert "### Instruction:" in rendered
        assert "tell me a joke" in rendered

    def test_prompt_style_w_instruct(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
        rendered = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        # Same expectation as the default style: header first, no system text.
        assert rendered.startswith("###")
        assert "### Instruction:" in rendered
        assert "tell me a joke" in rendered

    def test_prompt_style_w_chat(self):
        prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
        rendered = next(
            prompter.build_prompt("tell me a joke about the following", "alpacas")
        )
        # Chat style opens with the user turn instead of a section header.
        assert rendered.startswith("USER:")
        assert "USER:" in rendered
        assert "tell me a joke" in rendered
|
|
|
|
|
|
class MultipleChoiceExplainPrompterTest(unittest.TestCase):
    """
    Test class for MultipleChoiceExplainPrompter
    """

    def test_prompt_style_w_chat(self):
        prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
        rendered = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
        # Chat style: user turn, the question, the fixed MCE instruction,
        # and the raw choices list must all survive templating.
        expected_fragments = (
            "USER:",
            "choose one",
            "Choose the answer that best answers the question.",
            "- A\n- B\n- C",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
|