From c4c4b906382fed7ec3b3dfb7ef2d6f4734962c60 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Aug 2025 09:30:04 -0400 Subject: [PATCH] add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093) * add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json * fix test import --- src/axolotl/cli/merge_lora.py | 5 ++- src/axolotl/cli/quantize.py | 1 + src/axolotl/core/builders/causal.py | 3 ++ src/axolotl/core/trainers/base.py | 17 +++++++- src/axolotl/train.py | 7 +++- src/axolotl/utils/config/__init__.py | 2 +- src/axolotl/utils/schemas/model.py | 6 +++ tests/e2e/test_tokenizer.py | 63 ++++++++++++++++++++++++++++ 8 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 tests/e2e/test_tokenizer.py diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 31fad1b29..657ddcfe4 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None: safe_serialization=safe_serialization, progressbar=True, ) - tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged")) + tokenizer.save_pretrained( + str(Path(cfg.output_dir) / "merged"), + save_jinja_files=cfg.tokenizer_save_jinja_files, + ) if processor: processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index 0782976fe..b8a8de781 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -84,5 +84,6 @@ def do_quantize( str(Path(output_dir) / "quantized"), safe_serialization=False, progressbar=True, + save_jinja_files=cfg.tokenizer_save_jinja_files, ) LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...") diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 94b0db851..e5bc68c39 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -404,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **trainer_kwargs, ) trainer = self.hook_post_create_trainer(trainer) + # if the trainer has the `axolotl_cfg` property, set it + if hasattr(trainer, "axolotl_cfg"): + trainer.axolotl_cfg = self.cfg for callback in self.get_post_trainer_create_callbacks(trainer): trainer.add_callback(callback) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 4b8861790..f707d4b5a 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -42,6 +42,7 @@ from axolotl.core.trainers.utils import ( ) from axolotl.utils import get_not_null from axolotl.utils.bench import get_gpu_memory_usage +from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -63,6 +64,15 @@ class AxolotlTrainer( args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] tag_names = ["axolotl"] + _axolotl_cfg: DictDefault | None = None + + @property + def axolotl_cfg(self): + return self._axolotl_cfg + + @axolotl_cfg.setter + def axolotl_cfg(self, cfg): + self._axolotl_cfg = cfg def __init__( self, @@ -657,6 +667,11 @@ class AxolotlTrainer( LOG.info( "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`" ) - self.data_collator.tokenizer.save_pretrained(output_dir) + save_jinja_files = True + if self.axolotl_cfg: + save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files + self.data_collator.tokenizer.save_pretrained( + output_dir, save_jinja_files=save_jinja_files + ) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index e409d4a11..e8e314579 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -416,7 +416,9 @@ def save_initial_configs( # Pre-save the tokenizer and model configs LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...") - tokenizer.save_pretrained(str(output_dir)) + tokenizer.save_pretrained( + str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files + ) if hasattr(model, "config"): LOG.info(f"Pre-saving model config to {cfg.output_dir}...") model.config.save_pretrained(str(output_dir)) @@ -592,6 +594,9 @@ def train( # Save the trained model and cleanup save_trained_model(cfg, trainer, model, safe_serialization) + tokenizer.save_pretrained( + str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files + ) create_model_card(cfg, trainer) if not cfg.use_ray: cleanup_distributed() diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 534d7c4a4..2b6ef8d98 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -77,7 +77,7 @@ def resolve_dtype(cfg): if cfg.device == "mps": cfg.load_in_8bit = False cfg.tf32 = False - if cfg.bf16: + if cfg.bf16 and cfg.fp16 is not False: cfg.fp16 = True cfg.bf16 = False else: diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index eb751bfcc..56b206b51 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -59,6 +59,12 @@ class ModelInputConfig(BaseModel): processor_type: str | None = Field( default=None, json_schema_extra={"description": "transformers processor class"} ) + tokenizer_save_jinja_files: bool | None = Field( + default=True, # match the default behavior from transformers + json_schema_extra={ + "description": "Whether to save jinja files for tokenizer, transformers default is True" + }, + ) trust_remote_code: bool | None = Field( default=None, json_schema_extra={"description": "Trust remote code for untrusted source"}, diff --git a/tests/e2e/test_tokenizer.py b/tests/e2e/test_tokenizer.py new file mode 100644 index 000000000..a65c17ac3 --- /dev/null +++ b/tests/e2e/test_tokenizer.py @@ -0,0 +1,63 @@ +""" +e2e test for saving the tokenizer +""" + +from unittest.mock import patch + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from tests.e2e.utils import check_model_output_exists + + +def test_tokenizer_no_save_jinja_files(temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.02, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "chat_template": "chatml", + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_first_step": False, + "fp16": False, + "tokenizer_save_jinja_files": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + with patch("axolotl.train.execute_training"): + train(cfg=cfg, dataset_meta=dataset_meta) + + check_model_output_exists(temp_dir, cfg) + with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = f.read() + assert "chat_template" in tokenizer_config