add chat_template_jinja to wandb (#3192) [skip ci]

* add chat_template_jinja to wandb

* temp_ct_file.flush()

* Update src/axolotl/utils/callbacks/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update src/axolotl/utils/callbacks/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Apply suggestion from @winglian

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
VED
2025-10-09 21:35:54 +05:30
committed by GitHub
parent ab63b92c38
commit 37f78c8592

View File

@@ -16,6 +16,7 @@ import pandas as pd
import torch
import torch.distributed as dist
import wandb
import yaml
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
@@ -759,6 +760,37 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
try:
with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
chat_tpl = cfg.get("chat_template_jinja")
if chat_tpl:
with NamedTemporaryFile(
mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
) as temp_ct_file:
if (
isinstance(chat_tpl, str)
and os.path.exists(chat_tpl)
and os.path.isfile(chat_tpl)
):
copyfile(chat_tpl, temp_ct_file.name)
else:
temp_ct_file.write(str(chat_tpl))
temp_ct_file.flush()
artifact = wandb.Artifact(
f"chat-template-{wandb.run.id}", type="jinja-template"
)
artifact.add_file(temp_ct_file.name)
wandb.log_artifact(artifact)
wandb.save(temp_ct_file.name)
LOG.info(
"The chat_template_jinja has been saved to the WandB run under files."
)
except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")
if args.deepspeed:
try:
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.