This commit is contained in:
Dan Saunders
2025-08-14 01:51:13 -04:00
parent 3156c605d4
commit d8b63804bc

View File

@@ -10,6 +10,7 @@ import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
@@ -385,18 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
# Check if trainer class inherits from transformers.Trainer
# If so, we should pass the tokenizer/processing_class even if not in direct signature
from transformers import Trainer as HFTrainer
if "processing_class" in sig.parameters:
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
trainer_kwargs["processing_class"] = self.tokenizer
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
elif issubclass(trainer_cls, HFTrainer):
# For subclasses of transformers.Trainer, try processing_class first (newer HF versions)
trainer_kwargs["processing_class"] = self.tokenizer
if (
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None