From 50421c8b1dd63cceef0a06e9700a1c3e834685f2 Mon Sep 17 00:00:00 2001 From: Ram <9160496+0-hero@users.noreply.github.com> Date: Sat, 11 May 2024 09:38:04 +0530 Subject: [PATCH] feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553) * Add prompt strategies * Update modified URL * Update modified URL * Update fastchat_conversation_turns.py * Update register function * Remove extra /n for system prompt * Fix return * Fix BOS * Update requirements, pylint * Linting * Linting * fix tuples, make sure to set system message in template * tests for llama3 tokenization * fix conditionals for loading chat template --------- Co-authored-by: Ram Co-authored-by: Wing Lian --- src/axolotl/cli/preprocess.py | 28 ++++++++--- src/axolotl/cli/train.py | 13 ++++- src/axolotl/prompt_strategies/sharegpt.py | 20 +++++++- src/axolotl/prompters.py | 1 + src/axolotl/utils/chat_templates.py | 1 + .../config/models/input/v0_4_1/__init__.py | 1 + tests/prompt_strategies/test_sharegpt.py | 50 ++++++++++++++++++- 7 files changed, 102 insertions(+), 12 deletions(-) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index fa71d6793..e7b3596a4 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -19,7 +19,10 @@ from axolotl.cli import ( ) from axolotl.common.cli import PreprocessCliArgs from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.prompt_strategies.sharegpt import register_chatml_template +from axolotl.prompt_strategies.sharegpt import ( + register_chatml_template, + register_llama3_template, +) LOG = logging.getLogger("axolotl.cli.preprocess") @@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return_remaining_strings=True ) - if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message: - LOG.info( - f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" - ) - register_chatml_template(parsed_cfg.default_system_message) - else: - register_chatml_template() + if parsed_cfg.chat_template == "chatml": + if parsed_cfg.default_system_message: + LOG.info( + f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_chatml_template(parsed_cfg.default_system_message) + else: + register_chatml_template() + elif parsed_cfg.chat_template == "llama3": + if parsed_cfg.default_system_message: + LOG.info( + f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}" + ) + register_llama3_template(parsed_cfg.default_system_message) + else: + register_llama3_template() if not parsed_cfg.dataset_prepared_path: msg = ( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 0cebe5a52..7bb4a5184 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -19,7 +19,10 @@ from axolotl.cli import ( print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs -from axolotl.prompt_strategies.sharegpt import register_chatml_template +from axolotl.prompt_strategies.sharegpt import ( + register_chatml_template, + register_llama3_template, +) from axolotl.train import train LOG = logging.getLogger("axolotl.cli.train") @@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: register_chatml_template() + if cfg.chat_template == "llama3" and cfg.default_system_message: + LOG.info( + f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}" + ) + register_llama3_template(cfg.default_system_message) + else: + register_llama3_template() + if cfg.rl: # and cfg.rl != "orpo": dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index b556b3583..5f0e7a895 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -22,7 +22,7 @@ def register_chatml_template(system_message=None): name="chatml", system_template="<|im_start|>system\n{system_message}", system_message=system_message, - roles=["<|im_start|>user", "<|im_start|>assistant"], + roles=("<|im_start|>user", "<|im_start|>assistant"), sep_style=SeparatorStyle.CHATML, sep="<|im_end|>", ) @@ -32,13 +32,29 @@ def register_chatml_template(system_message=None): name="chatml_glaive", system_template="<|im_start|>system\n{system_message}", system_message=system_message, - roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"], + roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"), sep_style=SeparatorStyle.CHATML, sep="<|im_end|>", ) ) +def register_llama3_template(system_message=None): + system_message = system_message or "You are a helpful assistant." + register_conv_template( + Conversation( + name="llama3", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + system_message=system_message, + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str="<|eot_id|>", + stop_token_ids=[128001, 128009], + ) + ) + + def build_loader( tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"], prompter_cls: Type["ShareGPTPrompterV2"], diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 7a089c0ec..60ea5c99f 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = { "chatml": "<|im_start|>{ROLE}", "zephyr": "<|{ROLE}|>", "vicuna_v1.1": "{ROLE}", + "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>", } diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index c1dde8c0f..01b147356 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -24,6 +24,7 @@ def chat_templates(user_choice: str): "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}", } if user_choice in templates: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 419deee58..0a2442d50 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -143,6 +143,7 @@ class ChatTemplate(str, Enum): inst = "inst" # pylint: disable=invalid-name gemma = "gemma" # pylint: disable=invalid-name cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name class LoftQConfig(BaseModel): diff --git a/tests/prompt_strategies/test_sharegpt.py b/tests/prompt_strategies/test_sharegpt.py index 3ff0eab05..6e6909834 100644 --- a/tests/prompt_strategies/test_sharegpt.py +++ b/tests/prompt_strategies/test_sharegpt.py @@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import ( GlaiveShareGPTPromptTokenizingStrategy, SimpleShareGPTPromptTokenizingStrategy, register_chatml_template, + register_llama3_template, ) from axolotl.prompters import ShareGPTPrompterV2 register_chatml_template() +register_llama3_template() @pytest.fixture(name="sharegpt_dataset") @@ -115,7 +117,53 @@ def fixture_tokenizer(): return tokenizer -class TestSharegpt: +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") + tokenizer.eos_token = "<|eot_id|>" + + return tokenizer + + +class TestSharegptLlama3: + """Test class for ShareGPT style datasets with llama-3 prompts""" + + def test_tokenization(self, sharegpt_dataset, llama3_tokenizer): + strategy = SimpleShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation="llama3", + role_key_model=None, + role_key_human=None, + ), + llama3_tokenizer, + False, # train_on_inputs + 2048, # sequence_len + ) + + dataset_wrapper = TokenizedPromptDataset( + strategy, sharegpt_dataset, process_count=1 + ) + + input_ids = dataset_wrapper[0]["input_ids"] + + # fmt: off + assert input_ids == [ + 128000, # bos + 128006, 9125, 128007, # system header + 271, 31724, 128009, # sys prompt, eot + 128006, 882, 128007, # user header + 271, 15339, 128009, # user prompt eot + 128006, 78191, 128007, # assistant header + 271, 15339, 128009, # assistant response eot + 128006, 882, 128007, + 271, 19045, 29474, 128009, + 128006, 78191, 128007, + 271, 19045, 29474, 128009, + ] + # fmt: on + + +class TestSharegptChatML: """ Test class for sharegpt prompter """