diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5ed5837f2..da98600a4 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -9,8 +9,7 @@ from pathlib import Path from typing import Optional import torch - -# add src to the pythonpath so we don't need to pip install this +import transformers.modelcard from datasets import Dataset from optimum.bettertransformer import BetterTransformer @@ -103,6 +102,9 @@ def train( signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) ) + badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length") @@ -138,4 +140,7 @@ def train( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) + if not cfg.hub_model_id: + trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) + return model, tokenizer