create a model card with axolotl badge (#624)
This commit is contained in:
@@ -9,8 +9,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
import transformers.modelcard
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
@@ -103,6 +102,9 @@ def train(
|
||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
@@ -138,4 +140,7 @@ def train(
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
Reference in New Issue
Block a user