improve lisa callback logging
This commit is contained in:
@@ -6,6 +6,7 @@ Arxiv: https://arxiv.org/abs/2403.17919
|
|||||||
License: Apache 2.0
|
License: Apache 2.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -15,6 +16,8 @@ from transformers import TrainerCallback
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainer
|
from axolotl.core.trainer_builder import AxolotlTrainer
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
|
|
||||||
def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||||
class LISACallback(TrainerCallback):
|
class LISACallback(TrainerCallback):
|
||||||
@@ -34,16 +37,20 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
self.total_layers = len(
|
self.total_layers = len(
|
||||||
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
||||||
)
|
)
|
||||||
self.freeze_all_layers()
|
self.freeze_all_layers(True)
|
||||||
self.active_layers_indices = []
|
self.active_layers_indices = []
|
||||||
|
|
||||||
def freeze_all_layers(self):
|
def freeze_all_layers(self, summarize=False):
|
||||||
layers = reduce(
|
layers = reduce(
|
||||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||||
)
|
)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
if summarize:
|
||||||
|
LOG.info(
|
||||||
|
f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
|
||||||
|
)
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self, args, state, control, **kwargs
|
self, args, state, control, **kwargs
|
||||||
@@ -63,7 +70,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
self.active_layers_indices = np.random.choice(
|
self.active_layers_indices = np.random.choice(
|
||||||
range(self.total_layers), self.n_layers, replace=False
|
range(self.total_layers), self.n_layers, replace=False
|
||||||
)
|
)
|
||||||
print(
|
LOG.info(
|
||||||
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
f"Activating layers at indices: {self.active_layers_indices} for the next steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user