This commit is contained in:
Dan Saunders
2025-02-26 20:31:54 +00:00
parent c4104fc10c
commit a3224c7c3c
2 changed files with 21 additions and 23 deletions

View File

@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
Build the HuggingFace training args/trainer for causal models and reward modeling
using TRL.
"""
def get_callbacks(self):

View File

@@ -16,6 +16,7 @@ from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
@@ -278,7 +279,7 @@ def save_trained_model(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Any):
def create_model_card(cfg: DictDefault, trainer: Trainer):
"""
Create a model card for the trained model if needed.
@@ -287,6 +288,7 @@ def create_model_card(cfg: DictDefault, trainer: Any):
trainer: The trainer object with model card creation capabilities
"""
if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try:
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
@@ -294,26 +296,22 @@ def create_model_card(cfg: DictDefault, trainer: Any):
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_name"] = dataset_tags
else:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://")
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
if dataset_tags:
param_name = (
"dataset_name"
if (
cfg.rl is not None
or cfg.reward_model
or cfg.process_reward_model
)
else "dataset_tags"
)
model_card_kwarg[param_name] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):