add chat_template_jinja to wandb (#3192) [skip ci]
* add chat_template_jinja to wandb * temp_ct_file.flush() * Update src/axolotl/utils/callbacks/__init__.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * Update src/axolotl/utils/callbacks/__init__.py Co-authored-by: Wing Lian <wing.lian@gmail.com> * Apply suggestion from @winglian --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import pandas as pd
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import wandb
|
import wandb
|
||||||
|
import yaml
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -759,6 +760,37 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.axolotl_config_path, "r", encoding="utf-8") as f:
|
||||||
|
cfg = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
chat_tpl = cfg.get("chat_template_jinja")
|
||||||
|
if chat_tpl:
|
||||||
|
with NamedTemporaryFile(
|
||||||
|
mode="w", delete=True, suffix=".jinja", prefix="chat_template_"
|
||||||
|
) as temp_ct_file:
|
||||||
|
if (
|
||||||
|
isinstance(chat_tpl, str)
|
||||||
|
and os.path.exists(chat_tpl)
|
||||||
|
and os.path.isfile(chat_tpl)
|
||||||
|
):
|
||||||
|
copyfile(chat_tpl, temp_ct_file.name)
|
||||||
|
else:
|
||||||
|
temp_ct_file.write(str(chat_tpl))
|
||||||
|
temp_ct_file.flush()
|
||||||
|
|
||||||
|
artifact = wandb.Artifact(
|
||||||
|
f"chat-template-{wandb.run.id}", type="jinja-template"
|
||||||
|
)
|
||||||
|
artifact.add_file(temp_ct_file.name)
|
||||||
|
wandb.log_artifact(artifact)
|
||||||
|
wandb.save(temp_ct_file.name)
|
||||||
|
LOG.info(
|
||||||
|
"The chat_template_jinja has been saved to the WandB run under files."
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ConnectionError, yaml.YAMLError) as err:
|
||||||
|
LOG.warning(f"Error while saving chat_template_jinja to WandB: {err}")
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
try:
|
try:
|
||||||
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||||
|
|||||||
Reference in New Issue
Block a user