fix: switch to using the HuggingFace Transformers NEFT implementation (#941)
* fix: switch to using the HuggingFace Transformers NEFT implementation * linter * add support for noisy_embedding_alpha with a warning about it being renamed * restore pre/posttrain_hooks * move validation of NEFT noise alpha into validate_config() * linter
This commit is contained in:
@@ -16,7 +16,6 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.monkeypatch import neft_embeddings
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
@@ -180,21 +179,19 @@ def train(
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(cfg, trainer):
|
||||
def pretrain_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
||||
|
||||
|
||||
def post_train_hooks(cfg, trainer):
|
||||
def post_train_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.post_train_hook(cfg, trainer)
|
||||
|
||||
Reference in New Issue
Block a user