cleanup
This commit is contained in:
@@ -10,6 +10,7 @@ import transformers
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
|
Trainer,
|
||||||
)
|
)
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
@@ -385,18 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
)
|
)
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
|
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
|
||||||
# Check if trainer class inherits from transformers.Trainer
|
|
||||||
# If so, we should pass the tokenizer/processing_class even if not in direct signature
|
|
||||||
from transformers import Trainer as HFTrainer
|
|
||||||
|
|
||||||
if "processing_class" in sig.parameters:
|
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
elif "tokenizer" in sig.parameters:
|
elif "tokenizer" in sig.parameters:
|
||||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
elif issubclass(trainer_cls, HFTrainer):
|
|
||||||
# For subclasses of transformers.Trainer, try processing_class first (newer HF versions)
|
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
|
||||||
if (
|
if (
|
||||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||||
and self.cfg.datasets is not None
|
and self.cfg.datasets is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user