Ignore generation/endgeneration tags when analyzing Jinja chat template (#2787)

* ignore generation/endgeneration tags

Axolotl calculates the mask for assistant turns on its own, so these tags are not needed; however, the analyzer currently does not recognize them at all and throws an error.

* feat: add phi4 tokenizer test and unblock gemma2

* fix: improve template

* chore: refactor

* chore: lint

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Carsten Kragelund Jørgensen
2025-06-18 21:59:07 +02:00
committed by GitHub
parent 34da391391
commit eb3a57eb17
5 changed files with 46 additions and 49 deletions

View File

@@ -596,11 +596,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if (
turn_idx == 0
and turns[0].get("role") == "system"
and (
"mistral" in self.tokenizer.name_or_path.lower()
or "gemma"
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
)
and ("mistral" in self.tokenizer.name_or_path.lower())
):
return -1, -1

View File

@@ -3,6 +3,7 @@
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes
from jinja2.ext import Extension
class JinjaTemplateAnalysis(TypedDict):
@@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict):
iteration_target: Optional[Union[str, list[str]]]
class GenerationTagIgnore(Extension):
    """
    Ignores the generation and endgeneration tags in Jinja templates.

    Some chat templates wrap assistant turns in ``{% generation %}`` /
    ``{% endgeneration %}`` tags (used by HF tokenizers to compute an
    assistant-token mask). The analyzer computes that mask itself, so this
    extension registers both tag names and parses them into a no-op, which
    prevents the Jinja parser from raising on the unknown tags.
    """

    # Tag names this extension claims; Jinja dispatches parse() for both.
    tags = {"generation", "endgeneration"}

    def parse(self, parser):
        # Consume the tag-name token so the parser can continue past it.
        parser.stream.skip(1)
        # Emit an empty constant in place of the tag — effectively a no-op
        # node, so the tag contributes nothing to the analyzed template.
        return nodes.Const("")
class JinjaTemplateAnalyzer:
"""
Analyzes Jinja templates to extract information about variable usage,
@@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer:
"""
def __init__(self, template: str):
self.env: Environment = Environment(autoescape=True)
self.env: Environment = Environment(
autoescape=True, extensions=[GenerationTagIgnore]
)
self.property_access: Dict[str, Set[str]] = {}
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
self.index_access: Dict[str, Set[Union[int, float]]] = {}

File diff suppressed because one or more lines are too long

View File

@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
return tokenizer
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
def fixture_phi4_tokenizer():
    """Session-scoped fixture providing the Phi-4 reasoning tokenizer."""
    return AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
def fixture_gemma2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")

View File

@@ -33,15 +33,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja",
"[/INST]",
),
# TODO: temporarily skip gemma due to gemma3 template
# Re-enable on new chat_template implementation for perf
# (
# "gemma2_tokenizer",
# "jinja",
# "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
(
"gemma2_tokenizer",
"jinja",
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
("phi35_tokenizer", "phi_35", None, "<|end|>"),
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
]
@@ -95,11 +94,7 @@ class TestChatTemplateConfigurations:
if (
turn_idx == 0
and turn.get("from") in ["system", "context"]
and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
and ("mistral" in tokenizer.name_or_path.lower())
):
assert (
start_idx == -1 and end_idx == -1
@@ -935,36 +930,14 @@ class TestChatTemplateConfigurations:
"messages",
)
if chat_template == "llama3":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "chatml":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "phi_35":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
# Special case for Mistral with additional tool variables
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
# Most chat templates use the standard role and content variables
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
):
expected_variables = {"role", "content"}
else:
LOG.warning(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
@@ -973,6 +946,12 @@ class TestChatTemplateConfigurations:
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
assert variables == expected_variables, (
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
def test_eot_tokens_conflict_with_eos_token(
self,
tokenizer,