Ignore generation/endgeneration tags when analyzing Jinja chat template (#2787)
* ignore generation/endgeneration tags: Axolotl calculates the mask for assistant turns on its own, so these tags are not needed. However, the analyzer currently does not recognize them at all and throws an error.
* feat: add phi4 tokenizer test and unblock gemma2
* fix: improve template
* chore: refactor
* chore: lint
---------
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
committed by
GitHub
parent
34da391391
commit
eb3a57eb17
@@ -596,11 +596,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turns[0].get("role") == "system"
|
||||
and (
|
||||
"mistral" in self.tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
|
||||
)
|
||||
and ("mistral" in self.tokenizer.name_or_path.lower())
|
||||
):
|
||||
return -1, -1
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
from jinja2.ext import Extension
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
@@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict):
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class GenerationTagIgnore(Extension):
    """
    Jinja extension that makes ``generation`` / ``endgeneration`` tags no-ops.

    Registering the two tag names lets the environment parse templates that
    contain them; each occurrence is replaced by an empty constant node.
    """

    # Tag names this extension claims from the parser.
    tags = {"endgeneration", "generation"}

    def parse(self, parser):
        """Consume one token from the stream and emit an empty constant.

        NOTE(review): presumably the consumed token is the tag name itself
        (Jinja hands control to ``parse`` with the tag name as the current
        token) -- confirm against the Jinja extension docs.
        """
        # Advance the token stream by exactly one token.
        next(parser.stream)
        empty_node = nodes.Const("")
        return empty_node
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
@@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.env: Environment = Environment(
|
||||
autoescape=True, extensions=[GenerationTagIgnore]
|
||||
)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user