fix: address TrainerBuilderBase class variables to instance var
This commit is contained in:
@@ -48,17 +48,17 @@ with suppress(ImportError):
|
|||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""Base class for trainer builder."""
|
"""Base class for trainer builder."""
|
||||||
|
|
||||||
_train_dataset = None
|
|
||||||
_eval_dataset = None
|
|
||||||
_model_ref = None
|
|
||||||
_peft_config = None
|
|
||||||
|
|
||||||
def __init__(self, cfg, model, tokenizer, processor=None):
|
def __init__(self, cfg, model, tokenizer, processor=None):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
|
||||||
|
self._train_dataset = None
|
||||||
|
self._eval_dataset = None
|
||||||
|
self._model_ref = None
|
||||||
|
self._peft_config = None
|
||||||
|
|
||||||
# If the model supports tagging, add the axolotl tag.
|
# If the model supports tagging, add the axolotl tag.
|
||||||
# This makes sure the tag is correctly pushed even if a user calls
|
# This makes sure the tag is correctly pushed even if a user calls
|
||||||
# model.push_to_hub instead of trainer.push_to_hub.
|
# model.push_to_hub instead of trainer.push_to_hub.
|
||||||
|
|||||||
Reference in New Issue
Block a user