feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)
* Add prompt strategies * Update modified URL * Update modified URL * Update fastchat_conversation_turns.py * Update register function * Remove extra /n for system prompt * Fix return * Fix BOS * Update requirements, pylint * Linting * Linting * fix tuples, make sure to set system message in template * tests for llama3 tokenization * fix conditionals for loading chat template --------- Co-authored-by: Ram <ram@Rams-MacBook-Pro.local> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
)
|
)
|
||||||
from axolotl.common.cli import PreprocessCliArgs
|
from axolotl.common.cli import PreprocessCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||||
|
|
||||||
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
if parsed_cfg.chat_template == "chatml":
|
||||||
LOG.info(
|
if parsed_cfg.default_system_message:
|
||||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
LOG.info(
|
||||||
)
|
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
register_chatml_template(parsed_cfg.default_system_message)
|
)
|
||||||
else:
|
register_chatml_template(parsed_cfg.default_system_message)
|
||||||
register_chatml_template()
|
else:
|
||||||
|
register_chatml_template()
|
||||||
|
elif parsed_cfg.chat_template == "llama3":
|
||||||
|
if parsed_cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(parsed_cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if not parsed_cfg.dataset_prepared_path:
|
if not parsed_cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
from axolotl.prompt_strategies.sharegpt import (
|
||||||
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
|
)
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
LOG = logging.getLogger("axolotl.cli.train")
|
||||||
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
|
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||||
|
LOG.info(
|
||||||
|
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||||
|
)
|
||||||
|
register_llama3_template(cfg.default_system_message)
|
||||||
|
else:
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
if cfg.rl: # and cfg.rl != "orpo":
|
if cfg.rl: # and cfg.rl != "orpo":
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml",
|
name="chatml",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
|
|||||||
name="chatml_glaive",
|
name="chatml_glaive",
|
||||||
system_template="<|im_start|>system\n{system_message}",
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
system_message=system_message,
|
system_message=system_message,
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_llama3_template(system_message=None):
|
||||||
|
system_message = system_message or "You are a helpful assistant."
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="llama3",
|
||||||
|
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||||
|
system_message=system_message,
|
||||||
|
roles=("user", "assistant"),
|
||||||
|
sep_style=SeparatorStyle.LLAMA3,
|
||||||
|
sep="",
|
||||||
|
stop_str="<|eot_id|>",
|
||||||
|
stop_token_ids=[128001, 128009],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_loader(
|
def build_loader(
|
||||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||||
|
|||||||
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
|||||||
"chatml": "<|im_start|>{ROLE}",
|
"chatml": "<|im_start|>{ROLE}",
|
||||||
"zephyr": "<|{ROLE}|>",
|
"zephyr": "<|{ROLE}|>",
|
||||||
"vicuna_v1.1": "{ROLE}",
|
"vicuna_v1.1": "{ROLE}",
|
||||||
|
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
|||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||||
|
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||||
}
|
}
|
||||||
|
|
||||||
if user_choice in templates:
|
if user_choice in templates:
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
|
|||||||
inst = "inst" # pylint: disable=invalid-name
|
inst = "inst" # pylint: disable=invalid-name
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
|
|||||||
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
|||||||
GlaiveShareGPTPromptTokenizingStrategy,
|
GlaiveShareGPTPromptTokenizingStrategy,
|
||||||
SimpleShareGPTPromptTokenizingStrategy,
|
SimpleShareGPTPromptTokenizingStrategy,
|
||||||
register_chatml_template,
|
register_chatml_template,
|
||||||
|
register_llama3_template,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import ShareGPTPrompterV2
|
from axolotl.prompters import ShareGPTPrompterV2
|
||||||
|
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
register_llama3_template()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="sharegpt_dataset")
|
@pytest.fixture(name="sharegpt_dataset")
|
||||||
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
class TestSharegpt:
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
|
def fixture_llama3_tokenizer():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
tokenizer.eos_token = "<|eot_id|>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptLlama3:
|
||||||
|
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
||||||
|
|
||||||
|
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
||||||
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation="llama3",
|
||||||
|
role_key_model=None,
|
||||||
|
role_key_human=None,
|
||||||
|
),
|
||||||
|
llama3_tokenizer,
|
||||||
|
False, # train_on_inputs
|
||||||
|
2048, # sequence_len
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
strategy, sharegpt_dataset, process_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = dataset_wrapper[0]["input_ids"]
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
assert input_ids == [
|
||||||
|
128000, # bos
|
||||||
|
128006, 9125, 128007, # system header
|
||||||
|
271, 31724, 128009, # sys prompt, eot
|
||||||
|
128006, 882, 128007, # user header
|
||||||
|
271, 15339, 128009, # user prompt eot
|
||||||
|
128006, 78191, 128007, # assistant header
|
||||||
|
271, 15339, 128009, # assistant response eot
|
||||||
|
128006, 882, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
128006, 78191, 128007,
|
||||||
|
271, 19045, 29474, 128009,
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharegptChatML:
|
||||||
"""
|
"""
|
||||||
Test class for sharegpt prompter
|
Test class for sharegpt prompter
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user