Ignore generation/endgeneration tags when analyzing Jinja chat template (#2787)
* ignore generation/endgeneration tags Axolotl handles calculating the mask for assistant turns on its own, and as such these tags are not needed, however currently the analyzer does not recognize them at all and throws an error. * feat: add phi4 tokenizer test and unblock gemma2 * fix: improve template * chore: refactor * chore: lint --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
committed by
GitHub
parent
34da391391
commit
eb3a57eb17
@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
|
||||
def fixture_phi4_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_gemma2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
|
||||
|
||||
@@ -33,15 +33,14 @@ PARAMETRIZE_PARAMS = [
|
||||
"mistralv03_tokenizer_chat_template_jinja",
|
||||
"[/INST]",
|
||||
),
|
||||
# TODO: temporarily skip gemma due to gemma3 template
|
||||
# Re-enable on new chat_template implementation for perf
|
||||
# (
|
||||
# "gemma2_tokenizer",
|
||||
# "jinja",
|
||||
# "gemma2_tokenizer_chat_template_jinja",
|
||||
# "<end_of_turn>",
|
||||
# ),
|
||||
(
|
||||
"gemma2_tokenizer",
|
||||
"jinja",
|
||||
"gemma2_tokenizer_chat_template_jinja",
|
||||
"<end_of_turn>",
|
||||
),
|
||||
("phi35_tokenizer", "phi_35", None, "<|end|>"),
|
||||
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
|
||||
]
|
||||
|
||||
|
||||
@@ -95,11 +94,7 @@ class TestChatTemplateConfigurations:
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turn.get("from") in ["system", "context"]
|
||||
and (
|
||||
"mistral" in tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
|
||||
)
|
||||
and ("mistral" in tokenizer.name_or_path.lower())
|
||||
):
|
||||
assert (
|
||||
start_idx == -1 and end_idx == -1
|
||||
@@ -935,36 +930,14 @@ class TestChatTemplateConfigurations:
|
||||
"messages",
|
||||
)
|
||||
|
||||
if chat_template == "llama3":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "chatml":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "phi_35":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
# Special case for Mistral with additional tool variables
|
||||
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
|
||||
# Most chat templates use the standard role and content variables
|
||||
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
|
||||
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
|
||||
):
|
||||
expected_variables = {"role", "content"}
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
@@ -973,6 +946,12 @@ class TestChatTemplateConfigurations:
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
assert variables == expected_variables, (
|
||||
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
|
||||
def test_eot_tokens_conflict_with_eos_token(
|
||||
self,
|
||||
tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user