Compare commits
1 Commit
runpod-sls
...
20240404-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05f7034288 |
@@ -54,23 +54,33 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def on_train_begin(
|
||||||
|
self, args, state, control, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
self.switch_active_layers(state)
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# Check if it's time to switch active layers, including at step 0
|
# Check if it's time to switch active layers, including at step 0
|
||||||
if state.global_step % self.step_interval == 0 or state.global_step == 1:
|
if state.global_step % self.step_interval == 0:
|
||||||
self.switch_active_layers()
|
self.switch_active_layers(state)
|
||||||
|
|
||||||
def switch_active_layers(self):
|
def switch_active_layers(self, state):
|
||||||
# First, disable gradients for all layers
|
# First, disable gradients for all layers
|
||||||
self.freeze_all_layers()
|
self.freeze_all_layers()
|
||||||
|
|
||||||
|
deterministic_seed = state.global_step
|
||||||
|
np.random.seed(deterministic_seed)
|
||||||
|
|
||||||
# Randomly select n_layers to activate
|
# Randomly select n_layers to activate
|
||||||
layers = reduce(
|
layers = reduce(
|
||||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||||
)
|
)
|
||||||
self.active_layers_indices = np.random.choice(
|
self.active_layers_indices = np.random.choice(
|
||||||
range(self.total_layers), self.n_layers, replace=False
|
range(self.total_layers),
|
||||||
|
self.n_layers,
|
||||||
|
replace=False,
|
||||||
)
|
)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
||||||
|
|||||||
Reference in New Issue
Block a user