From 85dd4d525b9e65740ccb48d3c3897d35c9ae5265 Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Wed, 27 Dec 2023 19:25:33 -0800 Subject: [PATCH] add config to model card (#1005) * add config to model card * rm space * apply black formatting * apply black formatting * fix formatting * check for cfg attribute * add version * add version * put the config in a collapsible element * put the config in a collapsible element --- src/axolotl/train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 169fc5127..4e5241e4c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ import transformers.modelcard from accelerate.logging import get_logger from datasets import Dataset from optimum.bettertransformer import BetterTransformer +from pkg_resources import get_distribution # type: ignore from transformers.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs @@ -115,6 +116,12 @@ def train( badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = get_distribution("axolotl").version + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" + LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length")