fix LISA by ensuring params are not frozen during __init__
@@ -16,7 +16,7 @@ from transformers import TrainerCallback
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainer
 
-LOG = logging.getLogger("axolotl.callbacks")
+LOG = logging.getLogger("axolotl.callbacks.lisa")
 
 
 def lisa_callback_factory(trainer: "AxolotlTrainer"):
@@ -37,20 +37,22 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
             self.total_layers = len(
                 reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
             )
+            self.freeze_all_layers(True)
             self.active_layers_indices = []
 
-        def freeze_all_layers(self):
-            layers = reduce(
-                getattr, self.layers_attribute.split("."), self.trainer.model
-            )
-            LOG.info(
-                f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
-            )
+        def freeze_all_layers(self, summarize=False):
+            layers = reduce(
+                getattr, self.layers_attribute.split("."), self.trainer.model
+            )
             for layer in layers:
                 for param in layer.parameters():
                     param.requires_grad = False
+            if summarize:
+                LOG.info(
+                    f"Freezing {len(layers)} layers; will activate {self.n_layers} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
+                )
 
         def on_step_begin(
             self, args, state, control, **kwargs
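Before this change, nothing froze the model's parameters at construction time, so training could begin with every layer still trainable until the first scheduled layer switch. The sketch below is a minimal, self-contained illustration of the pattern the fixed callback follows, not the axolotl implementation: the LisaSketch class, the toy model, and the hyperparameter values are all illustrative assumptions. It freezes all layers once in __init__, then has the per-step hook re-freeze everything and unfreeze a random subset every step_interval steps.

    # Minimal sketch of the LISA freeze/activate pattern (illustrative, not axolotl's code).
    import random

    from torch import nn


    class LisaSketch:
        def __init__(self, model: nn.Module, n_layers: int = 2, step_interval: int = 3):
            self.layers = list(model.children())
            self.n_layers = n_layers
            self.step_interval = step_interval
            self.active_layers_indices = []
            # The fix this commit makes: freeze immediately in __init__, so no
            # step can run before the first layer switch with all params trainable.
            self.freeze_all_layers(summarize=True)

        def freeze_all_layers(self, summarize: bool = False) -> None:
            # Disable gradients for every layer's parameters.
            for layer in self.layers:
                for param in layer.parameters():
                    param.requires_grad = False
            if summarize:
                print(
                    f"Freezing {len(self.layers)} layers; will activate "
                    f"{self.n_layers} every {self.step_interval} steps"
                )

        def on_step_begin(self, step: int) -> None:
            if step % self.step_interval == 0:
                # Re-freeze everything, then unfreeze a fresh random subset.
                self.freeze_all_layers()
                self.active_layers_indices = random.sample(
                    range(len(self.layers)), self.n_layers
                )
                for idx in self.active_layers_indices:
                    for param in self.layers[idx].parameters():
                        param.requires_grad = True


    model = nn.Sequential(*(nn.Linear(8, 8) for _ in range(4)))
    callback = LisaSketch(model)
    callback.on_step_begin(step=0)  # only 2 of the 4 layers are now trainable

Passing True for summarize only from __init__, as the patched code does, keeps the log line to a single summary at startup instead of repeating it on every layer switch.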