fix: switch to using the HuggingFace Transformers NEFT implementation (#941)
* fix: switch to using the HuggingFace Transformers NEFT implementation * linter * add support for noisy_embedding_alpha with a warning about it being renamed * restore pre/posttrain_hooks * move validation of NEFT noise alpha into validate_config() * linter
This commit is contained in:
@@ -774,7 +774,7 @@ max_grad_norm:
|
|||||||
# Augmentation techniques
|
# Augmentation techniques
|
||||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||||
# currently only supported on Llama and Mistral
|
# currently only supported on Llama and Mistral
|
||||||
noisy_embedding_alpha:
|
neftune_noise_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
|
|||||||
@@ -712,6 +712,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
|
|
||||||
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"neftune_noise_alpha"
|
||||||
|
] = self.cfg.neftune_noise_alpha
|
||||||
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
"""
|
|
||||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from peft import PeftModel
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
|
||||||
def patch_neft(alpha, model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
embeddings.noisy_embedding_alpha = alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
|
||||||
embeddings, embeddings.__class__
|
|
||||||
)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def unpatch_neft(model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
if hasattr(embeddings, "_old_forward"):
|
|
||||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings.noisy_embedding_alpha
|
|
||||||
|
|
||||||
|
|
||||||
def neft_forward(self, inputs: torch.Tensor):
|
|
||||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
|
||||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
|
||||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
unpatch_neft(trainer.model)
|
|
||||||
@@ -16,7 +16,6 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
|
|||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.monkeypatch import neft_embeddings
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.freeze import freeze_parameters_except
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -180,21 +179,19 @@ def train(
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hooks(cfg, trainer):
|
def pretrain_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right before kicking off the training
|
Run hooks right before kicking off the training
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hooks(cfg, trainer):
|
def post_train_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right after training completes
|
Run hooks right after training completes
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.post_train_hook(cfg, trainer)
|
|
||||||
|
|||||||
@@ -434,6 +434,20 @@ def validate_config(cfg):
|
|||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.noisy_embedding_alpha is not None:
|
||||||
|
# Deprecated, use neftune_noise_alpha
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
if cfg.neftune_noise_alpha is None:
|
||||||
|
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||||
|
else:
|
||||||
|
# User is providing both; bail and have them sort out their settings
|
||||||
|
raise ValueError(
|
||||||
|
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
Reference in New Issue
Block a user