add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json * fix test import
This commit is contained in:
@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir) / "merged"),
|
||||||
|
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||||
|
)
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|||||||
@@ -84,5 +84,6 @@ def do_quantize(
|
|||||||
str(Path(output_dir) / "quantized"),
|
str(Path(output_dir) / "quantized"),
|
||||||
safe_serialization=False,
|
safe_serialization=False,
|
||||||
progressbar=True,
|
progressbar=True,
|
||||||
|
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||||
)
|
)
|
||||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||||
|
|||||||
@@ -404,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
trainer = self.hook_post_create_trainer(trainer)
|
trainer = self.hook_post_create_trainer(trainer)
|
||||||
|
# if the trainer has the `axolotl_cfg` property, set it
|
||||||
|
if hasattr(trainer, "axolotl_cfg"):
|
||||||
|
trainer.axolotl_cfg = self.cfg
|
||||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||||
trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from axolotl.core.trainers.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils import get_not_null
|
from axolotl.utils import get_not_null
|
||||||
from axolotl.utils.bench import get_gpu_memory_usage
|
from axolotl.utils.bench import get_gpu_memory_usage
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -63,6 +64,15 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
|
_axolotl_cfg: DictDefault | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def axolotl_cfg(self):
|
||||||
|
return self._axolotl_cfg
|
||||||
|
|
||||||
|
@axolotl_cfg.setter
|
||||||
|
def axolotl_cfg(self, cfg):
|
||||||
|
self._axolotl_cfg = cfg
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -657,6 +667,11 @@ class AxolotlTrainer(
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||||
)
|
)
|
||||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
save_jinja_files = True
|
||||||
|
if self.axolotl_cfg:
|
||||||
|
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
||||||
|
self.data_collator.tokenizer.save_pretrained(
|
||||||
|
output_dir, save_jinja_files=save_jinja_files
|
||||||
|
)
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|||||||
@@ -416,7 +416,9 @@ def save_initial_configs(
|
|||||||
|
|
||||||
# Pre-save the tokenizer and model configs
|
# Pre-save the tokenizer and model configs
|
||||||
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
||||||
tokenizer.save_pretrained(str(output_dir))
|
tokenizer.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
|
||||||
|
)
|
||||||
if hasattr(model, "config"):
|
if hasattr(model, "config"):
|
||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
model.config.save_pretrained(str(output_dir))
|
||||||
@@ -592,6 +594,9 @@ def train(
|
|||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
|
tokenizer.save_pretrained(
|
||||||
|
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
|
||||||
|
)
|
||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
if not cfg.use_ray:
|
if not cfg.use_ray:
|
||||||
cleanup_distributed()
|
cleanup_distributed()
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
|
|||||||
if cfg.device == "mps":
|
if cfg.device == "mps":
|
||||||
cfg.load_in_8bit = False
|
cfg.load_in_8bit = False
|
||||||
cfg.tf32 = False
|
cfg.tf32 = False
|
||||||
if cfg.bf16:
|
if cfg.bf16 and cfg.fp16 is not False:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -59,6 +59,12 @@ class ModelInputConfig(BaseModel):
|
|||||||
processor_type: str | None = Field(
|
processor_type: str | None = Field(
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
|
tokenizer_save_jinja_files: bool | None = Field(
|
||||||
|
default=True, # match the default behavior from transformers
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to save jinja files for tokenizer, transformers default is True"
|
||||||
|
},
|
||||||
|
)
|
||||||
trust_remote_code: bool | None = Field(
|
trust_remote_code: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||||
|
|||||||
63
tests/e2e/test_tokenizer.py
Normal file
63
tests/e2e/test_tokenizer.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
e2e test for saving the tokenizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_model_output_exists
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokenizer_no_save_jinja_files(temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.02,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_first_step": False,
|
||||||
|
"fp16": False,
|
||||||
|
"tokenizer_save_jinja_files": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
with patch("axolotl.train.execute_training"):
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||||
|
tokenizer_config = f.read()
|
||||||
|
assert "chat_template" in tokenizer_config
|
||||||
Reference in New Issue
Block a user