diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 9b03563e0..93384189e 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): return kwargs +def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): + if isinstance(dataset_tags, str): + dataset_tags = [dataset_tags] + + if (dataset_tags is not None) and (kwargs is not None): + if "dataset_tags" not in kwargs: + kwargs["dataset_tags"] = dataset_tags + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): + kwargs["dataset_tags"].extend(dataset_tags) + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): + dataset_tags.append(kwargs["dataset_tags"]) + kwargs["dataset_tags"] = dataset_tags + + return kwargs + + @dataclass class AxolotlTrainingMixins: """ @@ -418,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer): *_args, bench_data_collator=None, eval_data_collator=None, + dataset_tags=None, **kwargs, ): self.bench_data_collator = bench_data_collator self.eval_data_collator = eval_data_collator + self.dataset_tags = dataset_tags super().__init__(*_args, **kwargs) self.train_data_collator = self.data_collator self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -919,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tags=self.dataset_tags, kwargs=kwargs + ) kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) @@ -1042,8 +1063,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): tag_names = ["axolotl", "dpo"] - def __init__(self, *args, **kwargs): + def __init__(self, *args, dataset_tags=None, **kwargs): super().__init__(*args, **kwargs) + self.dataset_tags = dataset_tags self.optimizer = None def create_optimizer(self): @@ -1082,6 +1104,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): Overwrite the `push_to_hub` method in order to force-add the tags when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tags=self.dataset_tags, kwargs=kwargs + ) kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) return super().push_to_hub(*args, **kwargs) @@ -1806,6 +1831,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: trainer_kwargs["tokenizer"] = self.tokenizer + if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None: + trainer_kwargs["dataset_tags"] = [ + d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() + ] trainer = trainer_cls( model=self.model, train_dataset=self.train_dataset, @@ -2079,6 +2108,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: dpo_trainer_kwargs["tokenizer"] = self.tokenizer + if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer): + dpo_trainer_kwargs["dataset_tags"] = [ + d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() + ] dpo_trainer = trainer_cls( *trainer_cls_args, args=training_args, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 5fde4d384..39af9f45c 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -259,11 +259,31 @@ def train( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if not cfg.hub_model_id: + from huggingface_hub import HfApi + from huggingface_hub.utils import RepositoryNotFoundError + try: - trainer.create_model_card( - model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8") - ) - except (AttributeError, UnicodeDecodeError): + # Check to make sure the base model is from HuggingFace not a local directory + hf_api = HfApi() + hf_api.model_info(cfg.base_model) + + model_card_kwarg = { + "model_name": cfg.output_dir.lstrip("./") + .encode("utf-8") + .decode("utf-8") + } + if cfg.datasets is not None: + if cfg.rl is not None or cfg.reward_model: + model_card_kwarg["dataset_name"] = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + else: + model_card_kwarg["dataset_tags"] = [ + d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() + ] + + trainer.create_model_card(**model_card_kwarg) + except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError): pass elif cfg.hub_model_id: # defensively push to the hub to ensure the model card is updated