Add ds model card, rebased (#2101) [skip ci]
* rebased add_ds_model_card * manual rebasing * fix redundancy * lint * include case when ds_tag is none * conform to kwargs in create_model_card
This commit is contained in:
@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""
|
||||||
@@ -418,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
*_args,
|
*_args,
|
||||||
bench_data_collator=None,
|
bench_data_collator=None,
|
||||||
eval_data_collator=None,
|
eval_data_collator=None,
|
||||||
|
dataset_tags=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
@@ -919,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
@@ -1042,8 +1063,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
@@ -1082,6 +1104,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
@@ -1806,6 +1831,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
|
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
|
||||||
|
trainer_kwargs["dataset_tags"] = [
|
||||||
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
@@ -2079,6 +2108,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
|
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||||
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
|||||||
@@ -259,11 +259,31 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trainer.create_model_card(
|
# Check to make sure the base model is from HuggingFace not a local directory
|
||||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
hf_api = HfApi()
|
||||||
)
|
hf_api.model_info(cfg.base_model)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
|
||||||
|
model_card_kwarg = {
|
||||||
|
"model_name": cfg.output_dir.lstrip("./")
|
||||||
|
.encode("utf-8")
|
||||||
|
.decode("utf-8")
|
||||||
|
}
|
||||||
|
if cfg.datasets is not None:
|
||||||
|
if cfg.rl is not None or cfg.reward_model:
|
||||||
|
model_card_kwarg["dataset_name"] = [
|
||||||
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_card_kwarg["dataset_tags"] = [
|
||||||
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
|
|
||||||
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
|
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# defensively push to the hub to ensure the model card is updated
|
||||||
|
|||||||
Reference in New Issue
Block a user