From 05f70342883fdb75fb36edd09f5e8a72a25be3e9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 4 Apr 2024 18:16:55 -0700 Subject: [PATCH] use deterministic seed for random LISA layers --- src/axolotl/utils/callbacks/lisa.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index ff20959a5..ed8cd1e43 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): for param in layer.parameters(): param.requires_grad = False + def on_train_begin( + self, args, state, control, **kwargs + ): # pylint: disable=unused-argument + self.switch_active_layers(state) + def on_step_begin( self, args, state, control, **kwargs ): # pylint: disable=unused-argument # Check if it's time to switch active layers, including at step 0 - if state.global_step % self.step_interval == 0 or state.global_step == 1: - self.switch_active_layers() + if state.global_step % self.step_interval == 0: + self.switch_active_layers(state) - def switch_active_layers(self): + def switch_active_layers(self, state): # First, disable gradients for all layers self.freeze_all_layers() + deterministic_seed = state.global_step + np.random.seed(deterministic_seed) + # Randomly select n_layers to activate layers = reduce( getattr, self.layers_attribute.split("."), self.trainer.model ) self.active_layers_indices = np.random.choice( - range(self.total_layers), self.n_layers, replace=False + range(self.total_layers), + self.n_layers, + replace=False, ) LOG.info( f"Activating layers at indices: {self.active_layers_indices} for the next steps."