From d8b63804bcc2dde359e40279bada709e25448b01 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 14 Aug 2025 01:51:13 -0400 Subject: [PATCH] cleanup --- src/axolotl/core/builders/causal.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 07bea1237..94a0ad946 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -10,6 +10,7 @@ import transformers from transformers import ( DataCollatorWithFlattening, EarlyStoppingCallback, + Trainer, ) from trl.trainer.utils import RewardDataCollatorWithPadding @@ -385,18 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): **data_collator_kwargs, ) sig = inspect.signature(trainer_cls) - - # Check if trainer class inherits from transformers.Trainer - # If so, we should pass the tokenizer/processing_class even if not in direct signature - from transformers import Trainer as HFTrainer - - if "processing_class" in sig.parameters: + if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer): trainer_kwargs["processing_class"] = self.tokenizer elif "tokenizer" in sig.parameters: trainer_kwargs["tokenizer"] = self.tokenizer - elif issubclass(trainer_cls, HFTrainer): - # For subclasses of transformers.Trainer, try processing_class first (newer HF versions) - trainer_kwargs["processing_class"] = self.tokenizer + if ( trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer] and self.cfg.datasets is not None