Compare commits

...

1 Commit

Author SHA1 Message Date
Wing Lian
05f7034288 use deterministic seed for random LISA layers 2024-04-04 18:16:55 -07:00

View File

@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = False param.requires_grad = False
def on_train_begin(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
self.switch_active_layers(state)
def on_step_begin( def on_step_begin(
self, args, state, control, **kwargs self, args, state, control, **kwargs
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# Check if it's time to switch active layers, including at step 0 # Check if it's time to switch active layers, including at step 0
if state.global_step % self.step_interval == 0 or state.global_step == 1: if state.global_step % self.step_interval == 0:
self.switch_active_layers() self.switch_active_layers(state)
def switch_active_layers(self): def switch_active_layers(self, state):
# First, disable gradients for all layers # First, disable gradients for all layers
self.freeze_all_layers() self.freeze_all_layers()
deterministic_seed = state.global_step
np.random.seed(deterministic_seed)
# Randomly select n_layers to activate # Randomly select n_layers to activate
layers = reduce( layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model getattr, self.layers_attribute.split("."), self.trainer.model
) )
self.active_layers_indices = np.random.choice( self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False range(self.total_layers),
self.n_layers,
replace=False,
) )
LOG.info( LOG.info(
f"Activating layers at indices: {self.active_layers_indices} for the next steps." f"Activating layers at indices: {self.active_layers_indices} for the next steps."