diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index 4df3225bb..6509cd279 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -6,6 +6,7 @@ Arxiv: https://arxiv.org/abs/2403.17919 License: Apache 2.0 """ +import logging from functools import reduce from typing import TYPE_CHECKING @@ -15,6 +16,8 @@ from transformers import TrainerCallback if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainer +LOG = logging.getLogger("axolotl.callbacks") + def lisa_callback_factory(trainer: "AxolotlTrainer"): class LISACallback(TrainerCallback): @@ -34,16 +37,20 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): self.total_layers = len( reduce(getattr, self.layers_attribute.split("."), self.trainer.model) ) - self.freeze_all_layers() + self.freeze_all_layers(True) self.active_layers_indices = [] - def freeze_all_layers(self): + def freeze_all_layers(self, summarize=False): layers = reduce( getattr, self.layers_attribute.split("."), self.trainer.model ) for layer in layers: for param in layer.parameters(): param.requires_grad = False + if summarize: + LOG.info( + f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps" + ) def on_step_begin( self, args, state, control, **kwargs @@ -63,7 +70,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): self.active_layers_indices = np.random.choice( range(self.total_layers), self.n_layers, replace=False ) - print( + LOG.info( f"Activating layers at indices: {self.active_layers_indices} for the next steps." )