diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py index cf78b7186..0900ca7f1 100644 --- a/src/axolotl/core/trainer_builder/base.py +++ b/src/axolotl/core/trainer_builder/base.py @@ -48,17 +48,17 @@ with suppress(ImportError): class TrainerBuilderBase(abc.ABC): """Base class for trainer builder.""" - _train_dataset = None - _eval_dataset = None - _model_ref = None - _peft_config = None - def __init__(self, cfg, model, tokenizer, processor=None): self.cfg = cfg self.model = model self.tokenizer = tokenizer self.processor = processor + self._train_dataset = None + self._eval_dataset = None + self._model_ref = None + self._peft_config = None + # If the model supports tagging, add the axolotl tag. # This makes sure the tag is correctly pushed even if a user calls # model.push_to_hub instead of trainer.push_to_hub.