From 6185cd522776bf62c406a677be50dab9d7aca7d3 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Mon, 1 Apr 2024 06:57:28 +0000
Subject: [PATCH] fix LISA by ensuring params are not frozen during __init__

---
 src/axolotl/utils/callbacks/lisa.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
index 6509cd279..ff20959a5 100644
--- a/src/axolotl/utils/callbacks/lisa.py
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -16,7 +16,7 @@ from transformers import TrainerCallback
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainer

-LOG = logging.getLogger("axolotl.callbacks")
+LOG = logging.getLogger("axolotl.callbacks.lisa")


 def lisa_callback_factory(trainer: "AxolotlTrainer"):
@@ -37,20 +37,22 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
         self.total_layers = len(
             reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
         )
-        self.freeze_all_layers(True)
         self.active_layers_indices = []

-    def freeze_all_layers(self, summarize=False):
+        layers = reduce(
+            getattr, self.layers_attribute.split("."), self.trainer.model
+        )
+        LOG.info(
+            f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
+        )
+
+    def freeze_all_layers(self):
         layers = reduce(
             getattr, self.layers_attribute.split("."), self.trainer.model
         )
         for layer in layers:
             for param in layer.parameters():
                 param.requires_grad = False
-        if summarize:
-            LOG.info(
-                f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
-            )

     def on_step_begin(
         self, args, state, control, **kwargs
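
For readers outside the axolotl codebase, here is a minimal self-contained sketch of the rotation pattern this callback implements: freeze every decoder layer, then unfreeze a small random subset every step_interval steps. The diff above cuts off before the body of on_step_begin, so the class name, the switch_active_layers helper, and the random-sampling policy below are assumptions for illustration, not the verbatim axolotl implementation.

# A hedged sketch of LISA-style layer rotation, assuming a torch model whose
# decoder layers live at a dotted attribute path (e.g. "model.layers").
# switch_active_layers is a hypothetical name; the patch above does not show
# the real rotation body.
import random
from functools import reduce


class LisaRotationSketch:
    def __init__(self, model, layers_attribute="model.layers", n_layers=4, step_interval=20):
        self.model = model
        self.layers_attribute = layers_attribute
        self.n_layers = n_layers
        self.step_interval = step_interval
        self.active_layers_indices = []
        # Per the patch above, __init__ only does bookkeeping; it no longer
        # calls freeze_all_layers, so params stay unfrozen at construction.

    def _layers(self):
        # Walk the dotted path: model -> model.model -> model.model.layers, etc.
        return reduce(getattr, self.layers_attribute.split("."), self.model)

    def freeze_all_layers(self):
        for layer in self._layers():
            for param in layer.parameters():
                param.requires_grad = False

    def switch_active_layers(self):
        # Freeze everything, then unfreeze a fresh random subset of n_layers.
        self.freeze_all_layers()
        layers = self._layers()
        self.active_layers_indices = random.sample(range(len(layers)), self.n_layers)
        for idx in self.active_layers_indices:
            for param in layers[idx].parameters():
                param.requires_grad = True

    def on_step_begin(self, global_step):
        # Rotate the trainable subset at step 0 and every step_interval steps.
        if global_step % self.step_interval == 0:
            self.switch_active_layers()

In a real HF Trainer callback, on_step_begin receives args/state/control and would read state.global_step; a bare global_step argument stands in here to keep the sketch runnable outside a training loop.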