Ignore generation/endgeneration tags when analyzing Jinja chat template (#2787)
* ignore generation/endgeneration tags Axolotl handles calculating the mask for assistant turns on its own, and as such these tags are not needed, however currently the analyzer does not recognize them at all and throws an error. * feat: add phi4 tokenizer test and unblock gemma2 * fix: improve template * chore: refactor * chore: lint --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
committed by
GitHub
parent
34da391391
commit
eb3a57eb17
@@ -596,11 +596,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turns[0].get("role") == "system"
|
||||
and (
|
||||
"mistral" in self.tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
|
||||
)
|
||||
and ("mistral" in self.tokenizer.name_or_path.lower())
|
||||
):
|
||||
return -1, -1
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
from jinja2.ext import Extension
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
@@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict):
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class GenerationTagIgnore(Extension):
|
||||
"""
|
||||
Ignores the generation and endgeneration tags in Jinja templates.
|
||||
"""
|
||||
|
||||
tags = {"generation", "endgeneration"}
|
||||
|
||||
def parse(self, parser):
|
||||
parser.stream.skip(1)
|
||||
return nodes.Const("")
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
@@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.env: Environment = Environment(
|
||||
autoescape=True, extensions=[GenerationTagIgnore]
|
||||
)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
|
||||
def fixture_phi4_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_gemma2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
|
||||
|
||||
@@ -33,15 +33,14 @@ PARAMETRIZE_PARAMS = [
|
||||
"mistralv03_tokenizer_chat_template_jinja",
|
||||
"[/INST]",
|
||||
),
|
||||
# TODO: temporarily skip gemma due to gemma3 template
|
||||
# Re-enable on new chat_template implementation for perf
|
||||
# (
|
||||
# "gemma2_tokenizer",
|
||||
# "jinja",
|
||||
# "gemma2_tokenizer_chat_template_jinja",
|
||||
# "<end_of_turn>",
|
||||
# ),
|
||||
(
|
||||
"gemma2_tokenizer",
|
||||
"jinja",
|
||||
"gemma2_tokenizer_chat_template_jinja",
|
||||
"<end_of_turn>",
|
||||
),
|
||||
("phi35_tokenizer", "phi_35", None, "<|end|>"),
|
||||
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
|
||||
]
|
||||
|
||||
|
||||
@@ -95,11 +94,7 @@ class TestChatTemplateConfigurations:
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turn.get("from") in ["system", "context"]
|
||||
and (
|
||||
"mistral" in tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
|
||||
)
|
||||
and ("mistral" in tokenizer.name_or_path.lower())
|
||||
):
|
||||
assert (
|
||||
start_idx == -1 and end_idx == -1
|
||||
@@ -935,36 +930,14 @@ class TestChatTemplateConfigurations:
|
||||
"messages",
|
||||
)
|
||||
|
||||
if chat_template == "llama3":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "chatml":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "phi_35":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
# Special case for Mistral with additional tool variables
|
||||
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
|
||||
# Most chat templates use the standard role and content variables
|
||||
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
|
||||
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
|
||||
):
|
||||
expected_variables = {"role", "content"}
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
@@ -973,6 +946,12 @@ class TestChatTemplateConfigurations:
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
assert variables == expected_variables, (
|
||||
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
|
||||
def test_eot_tokens_conflict_with_eos_token(
|
||||
self,
|
||||
tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user