Compare commits
1 Commit
streaming-
...
20240404-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05f7034288 |
@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def on_train_begin(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
self.switch_active_layers(state)
|
||||
|
||||
def on_step_begin(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
# Check if it's time to switch active layers, including at step 0
|
||||
if state.global_step % self.step_interval == 0 or state.global_step == 1:
|
||||
self.switch_active_layers()
|
||||
if state.global_step % self.step_interval == 0:
|
||||
self.switch_active_layers(state)
|
||||
|
||||
def switch_active_layers(self):
|
||||
def switch_active_layers(self, state):
|
||||
# First, disable gradients for all layers
|
||||
self.freeze_all_layers()
|
||||
|
||||
deterministic_seed = state.global_step
|
||||
np.random.seed(deterministic_seed)
|
||||
|
||||
# Randomly select n_layers to activate
|
||||
layers = reduce(
|
||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||
)
|
||||
self.active_layers_indices = np.random.choice(
|
||||
range(self.total_layers), self.n_layers, replace=False
|
||||
range(self.total_layers),
|
||||
self.n_layers,
|
||||
replace=False,
|
||||
)
|
||||
LOG.info(
|
||||
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
||||
|
||||
Reference in New Issue
Block a user