Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
NanoCode012
2024-10-14 10:36:08 +07:00
35 changed files with 1774 additions and 32 deletions

View File

View File

View File

@@ -0,0 +1,197 @@
"""
Tests for the chat messages module
"""
import unittest
import pytest
from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
@pytest.fixture(scope="session", name="llama_tokenizer")
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")
def llama_tokenizer_w_chatml(llama_tokenizer):
llama_tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
llama_tokenizer.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return llama_tokenizer
@pytest.fixture(scope="session", name="chat_msgs")
def chat_msgs_fixture():
return {
"conversation": [
{
"role": "system",
"content": [
{"type": "text", "value": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "value": "What is today's stock price of Apple?"},
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_call",
"value": {
"name": "get_date",
"arguments": {},
},
},
{
"type": "tool_call",
"value": {
"name": "get_stock_price",
"arguments": {"symbol": "AAPL"},
},
},
],
"weight": 1,
},
{
"role": "tool",
"content": [
{
"type": "tool_response",
"value": {
"name": "get_date",
"content": {"date": "2024-09-09"},
},
},
{
"type": "tool_response",
"value": {
"name": "get_stock_price",
"content": {"symbol": "AAPL", "price": 123.45},
},
},
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"value": "The stock price of Apple is $123.45.\n",
"weight": 0,
},
{
"type": "text",
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
},
{
"type": "text",
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
},
],
"weight": 1,
},
]
}
class TestMessagesCase:
"""
Test cases for the chat messages module
"""
def test_tool_call_stringify(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
chat_msgs_as_obj.conversation[2].content[1].value
)
def test_chatml_formatted_wrapper(self, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
target_chatml = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is today's stock price of Apple?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_date", "arguments": {}}
</tool_call>
<tool_call>
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
</tool_call>
<|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_date", "content": {"date": "2024-09-09"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
</tool_response>
<|im_end|>
<|im_start|>assistant
The stock price of Apple is $123.45.
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
assert target_chatml == str(chat_msg_formatted)
def test_chatml_formatting_tool_call(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
assert target_chatml_turn2 == str(
format_message(chat_msgs_as_obj.conversation[2])
)
def test_train_labels(self, chatml_tokenizer, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
128256, # <|im_end|>
-100 # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
# also test if indivudal contents are set not to train
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
4513, 13, 1774, 13,
128256, # <|im_end|>
-100, # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
if __name__ == "__main__":
unittest.main()

View File

@@ -19,6 +19,8 @@ from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
@@ -346,3 +348,115 @@ class TestMultiGPULlama(unittest.TestCase):
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -0,0 +1,74 @@
"""
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraLlama(unittest.TestCase):
"""
Test case for Llama reward models using LoRA
"""
@with_temp_dir
def test_rm_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"model_type": "AutoModelForSequenceClassification",
"tokenizer_type": "LlamaTokenizer",
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
},
],
"remove_unused_columns": False,
"max_steps": 10,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -0,0 +1,62 @@
"""
tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import logging
import unittest
from axolotl.prompt_strategies.messages.chat import load
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestMessagesChatLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load(
llama3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"field_messages": "messages",
}
),
)
res = strategy.wrap_dataset(assistant_dataset)
input_ids = res[0]["input_ids"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
if __name__ == "__main__":
unittest.main()