Grokfast support (#1917)
This commit is contained in:
@@ -1287,6 +1287,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
callbacks.append(lisa_callback_factory(trainer))
|
callbacks.append(lisa_callback_factory(trainer))
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
callbacks.extend(
|
||||||
|
[
|
||||||
|
cb
|
||||||
|
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||||
|
self.cfg, trainer
|
||||||
|
)
|
||||||
|
if cb
|
||||||
|
]
|
||||||
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class BasePlugin:
|
|||||||
|
|
||||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer before training.
|
setup callbacks before creating the trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
@@ -155,14 +155,15 @@ class BasePlugin:
|
|||||||
self, cfg, trainer
|
self, cfg, trainer
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer after training.
|
Adds callbacks to the trainer after creating the trainer.
|
||||||
|
This is useful for callbacks that require access to the model or trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
trainer (object): The trainer object for training.
|
trainer (object): The trainer object for training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -393,7 +394,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||||
|
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
@@ -409,7 +412,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||||
|
if plugin_callbacks:
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
|
|||||||
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
13
src/axolotl/integrations/grokfast/README.md
Normal file
13
src/axolotl/integrations/grokfast/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Grokfast Optimizer
|
||||||
|
|
||||||
|
See https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.grokfast.GrokfastPlugin
|
||||||
|
|
||||||
|
grokfast_alpha: 0.98
|
||||||
|
grokfast_lamb: 2.0
|
||||||
|
```
|
||||||
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Grokfast plugin for Axolotl
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from ..base import BasePlugin
|
||||||
|
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from .optimizer import gradfilter_ema
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastCallbackHandler(TrainerCallback):
    """
    Trainer callback that applies the Grokfast EMA gradient filter.

    The callback keeps the running per-parameter gradient average between
    optimizer steps and hands it back to ``gradfilter_ema`` on every call.

    Parameters:
    alpha (float): Forwarded to ``gradfilter_ema`` as ``alpha``.
    lamb (float): Forwarded to ``gradfilter_ema`` as ``lamb``.
    """

    def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
        super().__init__(*args_, **kwargs)
        self.alpha = alpha
        self.lamb = lamb
        # Filter state; stays None until the first optimizer step populates it.
        self.grads = None

    def on_train_begin(self, *args_, **kwargs):  # pylint: disable=unused-argument
        # Drop any filter state left over from an earlier training run.
        self.grads = None

    def on_pre_optimizer_step(
        self, args_, state, control, **kwargs
    ):  # pylint: disable=unused-argument
        # The model is expected in the keyword arguments; a missing "model"
        # key raises KeyError.
        model = kwargs["model"]
        filtered_state = gradfilter_ema(
            model,
            self.grads,
            alpha=self.alpha,
            lamb=self.lamb,
        )
        self.grads = filtered_state
        return control
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastPlugin(BasePlugin):
    """
    Plugin for Grokfast optimizer integration with Axolotl.
    """

    def get_input_args(self):
        # Dotted path of the pydantic model that declares the plugin's options.
        return "axolotl.integrations.grokfast.GrokfastArgs"

    def add_callbacks_post_trainer(self, cfg, trainer):
        """
        Build the Grokfast callback from the plugin configuration.

        Parameters:
        cfg (dict): The configuration; ``grokfast_alpha`` and ``grokfast_lamb`` are read from it.
        trainer (object): The trainer object for training (not used here).

        Returns:
        List[callable]: A single-element list holding the Grokfast callback.
        """
        LOG.info("Adding Grokfast callback to the trainer")
        return [
            GrokfastCallbackHandler(
                alpha=cfg.grokfast_alpha,
                lamb=cfg.grokfast_lamb,
            )
        ]
|
||||||
15
src/axolotl/integrations/grokfast/args.py
Normal file
15
src/axolotl/integrations/grokfast/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
config args for grokfast plugin
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastArgs(BaseModel):
    """
    Configuration options accepted by the Grokfast plugin.
    """

    # Decay coefficient of the gradient moving average (the ``alpha`` of the EMA filter).
    grokfast_alpha: Optional[float] = 0.98
    # Multiplier applied to the averaged gradient before it is added to the raw gradient.
    grokfast_lamb: Optional[float] = 2.0
|
||||||
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
# Reference: https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ma(
    m: nn.Module,
    grads: Optional[Dict[str, deque]] = None,
    window_size: int = 100,
    lamb: float = 5.0,
    filter_type: Literal["mean", "sum"] = "mean",
    warmup: bool = True,
    trigger: bool = False,  # For ablation study.
) -> Dict[str, deque]:
    """
    Amplify gradients in place with a windowed moving-average filter.

    For every parameter of ``m`` that requires grad and currently has a
    gradient, the gradient is appended to a bounded history and, once the
    filter is active, ``lamb`` times the windowed mean (or sum) is added to
    ``p.grad``.

    Parameters:
    m: Module whose parameter gradients are filtered.
    grads: History returned by the previous call, or None to start fresh.
    window_size: Maximum number of past gradients kept per parameter.
    lamb: Scale of the filtered gradient added to the raw gradient.
    filter_type: "mean" averages the window, "sum" adds it up.
    warmup: When True, filtering starts only once the window is full.
    trigger: When True (and ``warmup`` is True), filtering is never applied.

    Returns:
    The updated history dict, to be passed to the next call.

    Raises:
    ValueError: If ``filter_type`` is neither "mean" nor "sum".
    """
    if grads is None:
        grads = {}

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            # Create the history lazily so that parameters which first receive
            # a gradient on a later call do not raise KeyError.
            if n not in grads:
                grads[n] = deque(maxlen=window_size)
            history = grads[n]
            history.append(p.grad.data.detach())  # .cpu())

            # Modify the gradients.
            # `and` binds tighter than `or`; the parentheses only make the
            # existing evaluation order explicit.
            if not warmup or (len(history) == window_size and not trigger):
                if filter_type == "mean":
                    avg = sum(history) / len(history)
                elif filter_type == "sum":
                    avg = sum(history)
                else:
                    raise ValueError(f"Unrecognized filter_type {filter_type}")
                p.grad.data = p.grad.data + avg * lamb

    return grads
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    """
    Amplify gradients in place with an exponential-moving-average filter.

    For every parameter of ``m`` that requires grad and currently has a
    gradient, the running average is updated as
    ``ema = ema * alpha + grad * (1 - alpha)`` and ``lamb * ema`` is added to
    ``p.grad``.

    Parameters:
    m: Module whose parameter gradients are filtered.
    grads: Running averages returned by the previous call, or None to start fresh.
    alpha: Decay coefficient of the moving average.
    lamb: Scale of the averaged gradient added to the raw gradient.

    Returns:
    The updated dict of running averages, to be passed to the next call.
    """
    if grads is None:
        grads = {}

    for n, p in m.named_parameters():
        if p.requires_grad and p.grad is not None:
            grad = p.grad.data.detach()
            # A parameter seen for the first time (on the first call, or when
            # it only starts receiving gradients later) seeds its average with
            # the current gradient instead of raising KeyError.
            previous = grads.get(n, grad)
            grads[n] = previous * alpha + grad * (1 - alpha)
            p.grad.data = p.grad.data + grads[n] * lamb

    return grads
|
||||||
@@ -783,6 +783,8 @@ class AxolotlInputConfig(
|
|||||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||||
|
|
||||||
|
plugins: Optional[List[str]] = Field(default=None)
|
||||||
|
|
||||||
@field_validator("datasets", mode="before")
|
@field_validator("datasets", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def deprecate_sharegpt_datasets(cls, datasets):
|
def deprecate_sharegpt_datasets(cls, datasets):
|
||||||
|
|||||||
Reference in New Issue
Block a user