improve lisa callback logging

This commit is contained in:
Aman Karmani
2024-04-01 04:54:00 +00:00
parent 21a5094226
commit b357c93f23

View File

@@ -6,6 +6,7 @@ Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0 License: Apache 2.0
""" """
import logging
from functools import reduce from functools import reduce
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@@ -15,6 +16,8 @@ from transformers import TrainerCallback
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainer from axolotl.core.trainer_builder import AxolotlTrainer
LOG = logging.getLogger("axolotl.callbacks")
def lisa_callback_factory(trainer: "AxolotlTrainer"): def lisa_callback_factory(trainer: "AxolotlTrainer"):
class LISACallback(TrainerCallback): class LISACallback(TrainerCallback):
@@ -34,16 +37,20 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.total_layers = len( self.total_layers = len(
reduce(getattr, self.layers_attribute.split("."), self.trainer.model) reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
) )
self.freeze_all_layers() self.freeze_all_layers(True)
self.active_layers_indices = [] self.active_layers_indices = []
def freeze_all_layers(self): def freeze_all_layers(self, summarize=False):
layers = reduce( layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model getattr, self.layers_attribute.split("."), self.trainer.model
) )
for layer in layers: for layer in layers:
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = False param.requires_grad = False
if summarize:
LOG.info(
f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
)
def on_step_begin( def on_step_begin(
self, args, state, control, **kwargs self, args, state, control, **kwargs
@@ -63,7 +70,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.active_layers_indices = np.random.choice( self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False range(self.total_layers), self.n_layers, replace=False
) )
print( LOG.info(
f"Activating layers at indices: {self.active_layers_indices} for the next steps." f"Activating layers at indices: {self.active_layers_indices} for the next steps."
) )