This commit is contained in:
Dan Saunders
2025-02-26 20:31:54 +00:00
parent c4104fc10c
commit a3224c7c3c
2 changed files with 21 additions and 23 deletions

View File

@@ -227,8 +227,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase): class HFCausalTrainerBuilder(TrainerBuilderBase):
""" """
Build the HuggingFace training args/trainer for causal models Build the HuggingFace training args/trainer for causal models and reward modeling
and reward modelling using TRL. using TRL.
""" """
def get_callbacks(self): def get_callbacks(self):

View File

@@ -16,6 +16,7 @@ from peft import PeftConfig, PeftModel
from pkg_resources import get_distribution # type: ignore from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
@@ -278,7 +279,7 @@ def save_trained_model(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def create_model_card(cfg: DictDefault, trainer: Any): def create_model_card(cfg: DictDefault, trainer: Trainer):
""" """
Create a model card for the trained model if needed. Create a model card for the trained model if needed.
@@ -287,6 +288,7 @@ def create_model_card(cfg: DictDefault, trainer: Any):
trainer: The trainer object with model card creation capabilities trainer: The trainer object with model card creation capabilities
""" """
if not cfg.hub_model_id: if not cfg.hub_model_id:
# Guard since create_model_card may fail if dataset_tags is empty list
try: try:
model_card_kwarg = { model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./") "model_name": cfg.output_dir.lstrip("./")
@@ -294,26 +296,22 @@ def create_model_card(cfg: DictDefault, trainer: Any):
.decode("utf-8") .decode("utf-8")
} }
if cfg.datasets is not None: if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model: dataset_tags = [
dataset_tags = [ d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() ]
] dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
dataset_tags = [
d for d in dataset_tags if not d.startswith("https://") if dataset_tags:
] param_name = (
if dataset_tags: "dataset_name"
# guard as create_model_card may fail if dataset_tags is empty list if (
model_card_kwarg["dataset_name"] = dataset_tags cfg.rl is not None
else: or cfg.reward_model
dataset_tags = [ or cfg.process_reward_model
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir() )
] else "dataset_tags"
dataset_tags = [ )
d for d in dataset_tags if not d.startswith("https://") model_card_kwarg[param_name] = dataset_tags
]
if dataset_tags:
# guard as create_model_card may fail if dataset_tags is empty list
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg) trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError): except (AttributeError, UnicodeDecodeError):