diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 12346b8a2..07753876f 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC): class HFCausalTrainerBuilder(TrainerBuilderBase): """ - Build the HuggingFace training args/trainer for causal models - and reward modelling using TRL. + Build the HuggingFace training args/trainer for causal models and reward modeling + using TRL. """ def get_callbacks(self): diff --git a/src/axolotl/train.py b/src/axolotl/train.py index eb2ced231..008b27add 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -16,6 +16,7 @@ from peft import PeftConfig, PeftModel from pkg_resources import get_distribution # type: ignore from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.trainer import Trainer from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module @@ -278,7 +279,7 @@ def save_trained_model( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) -def create_model_card(cfg: DictDefault, trainer: Any): +def create_model_card(cfg: DictDefault, trainer: Trainer): """ Create a model card for the trained model if needed. @@ -287,6 +288,7 @@ def create_model_card(cfg: DictDefault, trainer: Any): trainer: The trainer object with model card creation capabilities """ if not cfg.hub_model_id: + # Guard since create_model_card may fail if dataset_tags is empty list try: model_card_kwarg = { "model_name": cfg.output_dir.lstrip("./") @@ -294,26 +296,22 @@ def create_model_card(cfg: DictDefault, trainer: Any): .decode("utf-8") } if cfg.datasets is not None: - if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model: - dataset_tags = [ - d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() - ] - dataset_tags = [ - d for d in dataset_tags if not d.startswith("https://") - ] - if dataset_tags: - # guard as create_model_card may fail if dataset_tags is empty list - model_card_kwarg["dataset_name"] = dataset_tags - else: - dataset_tags = [ - d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() - ] - dataset_tags = [ - d for d in dataset_tags if not d.startswith("https://") - ] - if dataset_tags: - # guard as create_model_card may fail if dataset_tags is empty list - model_card_kwarg["dataset_tags"] = dataset_tags + dataset_tags = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + dataset_tags = [d for d in dataset_tags if not d.startswith("https://")] + + if dataset_tags: + param_name = ( + "dataset_name" + if ( + cfg.rl is not None + or cfg.reward_model + or cfg.process_reward_model + ) + else "dataset_tags" + ) + model_card_kwarg[param_name] = dataset_tags trainer.create_model_card(**model_card_kwarg) except (AttributeError, UnicodeDecodeError):