diff --git a/src/axolotl/train.py b/src/axolotl/train.py index ae10c2800..d61e200c6 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -345,19 +345,16 @@ def save_initial_configs( model.config.save_pretrained(str(output_dir)) -def setup_badge_for_model_card(): - """Set up the Axolotl badge for the model card.""" - badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" - transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" - - -def add_config_to_model_card(cfg: DictDefault): +def setup_model_card(cfg: DictDefault): """ - Add the Axolotl configuration to the model card if available. + Set up the Axolotl badge and add the Axolotl config to the model card if available. Args: cfg: The configuration dictionary with path to axolotl config file """ + badge_markdown = """[Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)""" + transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + if getattr(cfg, "axolotl_config_path"): raw_axolotl_cfg = Path(cfg.axolotl_config_path) version = get_distribution("axolotl").version @@ -460,8 +457,7 @@ def train( setup_signal_handler(cfg, model, safe_serialization) # Set up badges and config info for model card - setup_badge_for_model_card() - add_config_to_model_card(cfg) + setup_model_card(cfg) # Execute the training execute_training(cfg, trainer, resume_from_checkpoint)