diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 169fc5127..4e5241e4c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,6 +12,7 @@ import transformers.modelcard from accelerate.logging import get_logger from datasets import Dataset from optimum.bettertransformer import BetterTransformer +from pkg_resources import get_distribution # type: ignore from transformers.deepspeed import is_deepspeed_zero3_enabled from axolotl.common.cli import TrainerCliArgs @@ -115,6 +116,12 @@ def train( badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)""" transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + if getattr(cfg, "axolotl_config_path"): + raw_axolotl_cfg = Path(cfg.axolotl_config_path) + version = get_distribution("axolotl").version + if raw_axolotl_cfg.is_file(): + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n
See axolotl config\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n

\n" + LOG.info("Starting trainer...") if cfg.group_by_length: LOG.info("hang tight... sorting dataset for group_by_length")