51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
"""
|
|
Grokfast plugin for Axolotl
|
|
"""
|
|
import logging
|
|
|
|
from transformers.trainer_callback import TrainerCallback
|
|
|
|
from ..base import BasePlugin
|
|
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
|
from .optimizer import gradfilter_ema
|
|
|
|
# Module-level logger registered under the name "axolotl.integrations.grokfast".
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
|
|
|
|
|
class GrokfastCallbackHandler(TrainerCallback):
    """
    Trainer callback that applies ``gradfilter_ema`` before every optimizer step.

    The value returned by ``gradfilter_ema`` is stored on the instance as
    ``self.grads`` and handed back to it on the following step; it is reset
    to ``None`` each time training begins.
    """

    def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
        """
        Record the filter settings and start with no filter state.

        Args:
            *args_: Positional arguments forwarded to the parent initializer.
            alpha: Passed through as ``alpha`` to ``gradfilter_ema``.
            lamb: Passed through as ``lamb`` to ``gradfilter_ema``.
            **kwargs: Keyword arguments forwarded to the parent initializer.
        """
        super().__init__(*args_, **kwargs)
        self.alpha, self.lamb = alpha, lamb
        # No filter state exists until the first pre-optimizer hook fires.
        self.grads = None

    def on_train_begin(self, *args_, **kwargs):  # pylint: disable=unused-argument
        """Discard any filter state so a new training run starts clean."""
        self.grads = None

    def on_pre_optimizer_step(
        self, args_, state, control, **kwargs
    ):  # pylint: disable=unused-argument
        """
        Run ``gradfilter_ema`` on the model and keep its result for the next step.

        The model is read from the ``"model"`` keyword argument; a missing
        key raises ``KeyError``.

        Returns:
            The ``control`` object that was passed in, unchanged.
        """
        self.grads = gradfilter_ema(
            kwargs["model"],
            self.grads,
            alpha=self.alpha,
            lamb=self.lamb,
        )
        return control
|
|
|
|
|
|
class GrokfastPlugin(BasePlugin):
    """
    Plugin for Grokfast optimizer integration with Axolotl.

    Supplies the dotted path of the Grokfast argument schema and attaches a
    ``GrokfastCallbackHandler`` to the trainer.
    """

    def get_input_args(self):
        """Return the dotted import path of the Grokfast arguments class."""
        return "axolotl.integrations.grokfast.GrokfastArgs"

    def add_callbacks_post_trainer(self, cfg, trainer):
        """
        Build the Grokfast callback from the configuration.

        Args:
            cfg: Configuration object; ``grokfast_alpha`` and ``grokfast_lamb``
                are read from it.
            trainer: The trainer instance (not used here).

        Returns:
            A one-element list containing the new ``GrokfastCallbackHandler``.
        """
        LOG.info("Adding Grokfast callback to the trainer")
        return [
            GrokfastCallbackHandler(
                alpha=cfg.grokfast_alpha,
                lamb=cfg.grokfast_lamb,
            )
        ]
|