create a model card with axolotl badge (#624)

This commit is contained in:
Wing Lian
2023-09-22 16:13:26 -04:00
committed by GitHub
parent c25ba7939b
commit 501958bb6f

View File

@@ -9,8 +9,7 @@ from pathlib import Path
from typing import Optional
import torch
# add src to the pythonpath so we don't need to pip install this
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
@@ -103,6 +102,9 @@ def train(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
)
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
@@ -138,4 +140,7 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
return model, tokenizer