add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)

* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
Wing Lian
2025-08-26 09:30:04 -04:00
committed by GitHub
parent 0e9945e3b9
commit c4c4b90638
8 changed files with 100 additions and 4 deletions
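
For context, a minimal sketch (assuming a recent transformers release, since the save_jinja_files kwarg used throughout this commit is fairly new) of the behavior the option toggles; the output directory names are hypothetical:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
# True (the transformers default) writes the chat template, when the
# checkpoint has one, to a separate chat_template.jinja file
tok.save_pretrained("out_default", save_jinja_files=True)
# False keeps the template inline in tokenizer_config.json -- the legacy
# layout this commit lets users keep
tok.save_pretrained("out_legacy", save_jinja_files=False)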

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
         safe_serialization=safe_serialization,
         progressbar=True,
     )
-    tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+    tokenizer.save_pretrained(
+        str(Path(cfg.output_dir) / "merged"),
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
+    )
 
     if processor:
         processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -84,5 +84,6 @@ def do_quantize(
         str(Path(output_dir) / "quantized"),
         safe_serialization=False,
         progressbar=True,
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
     )
     LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -404,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             **trainer_kwargs,
         )
         trainer = self.hook_post_create_trainer(trainer)
+        # if the trainer has the `axolotl_cfg` property, set it
+        if hasattr(trainer, "axolotl_cfg"):
+            trainer.axolotl_cfg = self.cfg
         for callback in self.get_post_trainer_create_callbacks(trainer):
             trainer.add_callback(callback)
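
The hasattr guard keeps this wiring opt-in: only trainer classes that expose the axolotl_cfg property (added to AxolotlTrainer below) receive the config, while other trainer classes pass through untouched. A minimal sketch of the pattern; both classes here are illustrative stand-ins, not axolotl code:

class PlainTrainer:
    pass  # no axolotl_cfg property, so the builder skips it

class CfgAwareTrainer:
    _axolotl_cfg = None

    @property
    def axolotl_cfg(self):
        return self._axolotl_cfg

    @axolotl_cfg.setter
    def axolotl_cfg(self, cfg):
        self._axolotl_cfg = cfg

for trainer in (PlainTrainer(), CfgAwareTrainer()):
    if hasattr(trainer, "axolotl_cfg"):  # same check as in the builder
        trainer.axolotl_cfg = {"tokenizer_save_jinja_files": False}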

View File

@@ -42,6 +42,7 @@ from axolotl.core.trainers.utils import (
 )
 from axolotl.utils import get_not_null
 from axolotl.utils.bench import get_gpu_memory_usage
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -63,6 +64,15 @@ class AxolotlTrainer(
     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
     tag_names = ["axolotl"]
+    _axolotl_cfg: DictDefault | None = None
+
+    @property
+    def axolotl_cfg(self):
+        return self._axolotl_cfg
+
+    @axolotl_cfg.setter
+    def axolotl_cfg(self, cfg):
+        self._axolotl_cfg = cfg
 
     def __init__(
         self,
@@ -657,6 +667,11 @@ class AxolotlTrainer(
             LOG.info(
                 "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
             )
-            self.data_collator.tokenizer.save_pretrained(output_dir)
+            save_jinja_files = True
+            if self.axolotl_cfg:
+                save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
+            self.data_collator.tokenizer.save_pretrained(
+                output_dir, save_jinja_files=save_jinja_files
+            )
         # Good practice: save your training arguments together with the trained model
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
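
The _save change falls back to True (the transformers default) whenever no axolotl config was attached to the trainer, so only an attached cfg can opt into the legacy layout. A standalone sketch of that resolution logic, using a plain dict in place of DictDefault:

def resolve_save_jinja_files(axolotl_cfg) -> bool:
    # mirrors the fallback above: default True, cfg overrides when present
    save_jinja_files = True
    if axolotl_cfg:
        save_jinja_files = axolotl_cfg["tokenizer_save_jinja_files"]
    return save_jinja_files

assert resolve_save_jinja_files(None) is True
assert resolve_save_jinja_files({"tokenizer_save_jinja_files": False}) is False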

View File

@@ -416,7 +416,9 @@ def save_initial_configs(
     # Pre-save the tokenizer and model configs
     LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
-    tokenizer.save_pretrained(str(output_dir))
+    tokenizer.save_pretrained(
+        str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
+    )
     if hasattr(model, "config"):
         LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
         model.config.save_pretrained(str(output_dir))
@@ -592,6 +594,9 @@ def train(
     # Save the trained model and cleanup
     save_trained_model(cfg, trainer, model, safe_serialization)
+    tokenizer.save_pretrained(
+        str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
+    )
     create_model_card(cfg, trainer)
 
     if not cfg.use_ray:
         cleanup_distributed()

View File

@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
     if cfg.device == "mps":
         cfg.load_in_8bit = False
         cfg.tf32 = False
-        if cfg.bf16:
+        if cfg.bf16 and cfg.fp16 is not False:
            cfg.fp16 = True
        cfg.bf16 = False
    else:
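
The mps change makes the bf16-to-fp16 fallback respect an explicit fp16: false. The setting is tri-state (True, False, or None when unset); the old guard set fp16 = True whenever bf16 was requested, clobbering an explicit False, while `is not False` still lets None (unset) through. A quick illustration:

for fp16 in (True, None, False):
    # None (unset) still allows the fallback; only an explicit False blocks it
    print(f"fp16={fp16!r} -> fallback allowed: {fp16 is not False}")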

View File

@@ -59,6 +59,12 @@ class ModelInputConfig(BaseModel):
     processor_type: str | None = Field(
         default=None, json_schema_extra={"description": "transformers processor class"}
     )
+    tokenizer_save_jinja_files: bool | None = Field(
+        default=True,  # match the default behavior from transformers
+        json_schema_extra={
+            "description": "Whether to save jinja files for tokenizer, transformers default is True"
+        },
+    )
     trust_remote_code: bool | None = Field(
         default=None,
         json_schema_extra={"description": "Trust remote code for untrusted source"},

View File

@@ -0,0 +1,63 @@
+"""
+e2e test for saving the tokenizer
+"""
+
+from unittest.mock import patch
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_model_output_exists
+
+
+def test_tokenizer_no_save_jinja_files(temp_dir):
+    # pylint: disable=duplicate-code
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_type": "AutoTokenizer",
+            "sequence_len": 1024,
+            "load_in_8bit": True,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.05,
+            "lora_target_linear": True,
+            "val_set_size": 0.02,
+            "special_tokens": {
+                "pad_token": "<|endoftext|>",
+            },
+            "chat_template": "chatml",
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                },
+            ],
+            "num_epochs": 1,
+            "micro_batch_size": 2,
+            "gradient_accumulation_steps": 1,
+            "output_dir": temp_dir,
+            "learning_rate": 0.00001,
+            "optimizer": "adamw_torch_fused",
+            "lr_scheduler": "cosine",
+            "max_steps": 5,
+            "save_first_step": False,
+            "fp16": False,
+            "tokenizer_save_jinja_files": False,
+        }
+    )
+
+    cfg = validate_config(cfg)
+    normalize_config(cfg)
+    dataset_meta = load_datasets(cfg=cfg)
+
+    with patch("axolotl.train.execute_training"):
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
+
+    with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
+        tokenizer_config = f.read()
+    assert "chat_template" in tokenizer_config