diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index ae10c2800..d61e200c6 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -345,19 +345,16 @@ def save_initial_configs(
model.config.save_pretrained(str(output_dir))
-def setup_badge_for_model_card():
- """Set up the Axolotl badge for the model card."""
- badge_markdown = """[
](https://github.com/axolotl-ai-cloud/axolotl)"""
- transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
-
-
-def add_config_to_model_card(cfg: DictDefault):
+def setup_model_card(cfg: DictDefault):
"""
- Add the Axolotl configuration to the model card if available.
+ Set up the Axolotl badge and add the Axolotl config to the model card if available.
Args:
cfg: The configuration dictionary with path to axolotl config file
"""
+ badge_markdown = """[
](https://github.com/axolotl-ai-cloud/axolotl)"""
+ transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
@@ -460,8 +457,7 @@ def train(
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
- setup_badge_for_model_card()
- add_config_to_model_card(cfg)
+ setup_model_card(cfg)
# Execute the training
execute_training(cfg, trainer, resume_from_checkpoint)