50 lines
1.3 KiB
Python
50 lines
1.3 KiB
Python
"""
|
|
Grokfast plugin for Axolotl
|
|
"""
|
|
|
|
from transformers.trainer_callback import TrainerCallback
|
|
|
|
from axolotl.utils.logging import get_logger
|
|
|
|
from ..base import BasePlugin
|
|
from .args import GrokfastArgs as GrokfastArgs
|
|
from .optimizer import gradfilter_ema
|
|
|
|
LOG = get_logger(__name__)
|
|
|
|
|
|
class GrokfastCallbackHandler(TrainerCallback):
    """
    Trainer callback that runs the Grokfast gradient filter before each optimizer step.

    The value returned by ``gradfilter_ema`` is kept on the instance and passed
    back in on the next call, so the filter state carries over between steps.
    """

    def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
        """
        Store the filter hyperparameters and start with no filter state.

        Args:
            *args_: Extra positional arguments forwarded to the parent initializer.
            alpha: Forwarded to ``gradfilter_ema`` as ``alpha``.
            lamb: Forwarded to ``gradfilter_ema`` as ``lamb``.
            **kwargs: Extra keyword arguments forwarded to the parent initializer.
        """
        super().__init__(*args_, **kwargs)
        self.lamb = lamb
        self.alpha = alpha
        # No filter state exists until the first pre-optimizer step runs.
        self.grads = None

    def on_train_begin(self, *args_, **kwargs):
        """Discard any filter state so every training run starts fresh."""
        self.grads = None

    def on_pre_optimizer_step(self, args_, state, control, **kwargs):
        """
        Run ``gradfilter_ema`` on the model and remember the state it returns.

        Expects the model under the ``"model"`` keyword argument; a missing
        key raises ``KeyError``. Returns ``control`` unchanged.
        """
        target_model = kwargs.pop("model")
        # Feed the previous state back in; the result becomes the new state.
        self.grads = gradfilter_ema(
            target_model,
            self.grads,
            alpha=self.alpha,
            lamb=self.lamb,
        )
        return control
|
|
|
|
|
|
class GrokfastPlugin(BasePlugin):
    """
    Plugin for Grokfast optimizer integration with Axolotl.
    """

    def get_input_args(self):
        """Return the dotted import path of the Grokfast arguments class."""
        return "axolotl.integrations.grokfast.GrokfastArgs"

    def add_callbacks_post_trainer(self, cfg, trainer):
        """
        Build the Grokfast callback from the config values.

        Args:
            cfg: Config object; ``grokfast_alpha`` and ``grokfast_lamb`` are read from it.
            trainer: Not used by this method.

        Returns:
            A one-element list containing the new ``GrokfastCallbackHandler``.
        """
        LOG.info("Adding Grokfast callback to the trainer")
        return [
            GrokfastCallbackHandler(
                alpha=cfg.grokfast_alpha,
                lamb=cfg.grokfast_lamb,
            )
        ]
|