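Fix the default for `lisa_layers_attribute` (now "model.layers") and fix attribute traversal for layers by resolving the dotted path with `reduce(getattr, ...)` instead of `ast.literal_eval`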

This commit is contained in:
Wing Lian
2024-03-31 00:27:04 -04:00
parent 3a9ad7c66e
commit 21a5094226
2 changed files with 19 additions and 10 deletions

View File

@@ -1,5 +1,12 @@
"""module for LISA""" """
import ast module for LISA
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""
from functools import reduce
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np import numpy as np
@@ -22,16 +29,18 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.layers_attribute = layers_attribute self.layers_attribute = layers_attribute
self.trainer = trainer self.trainer = trainer
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
self.total_layers = len( self.total_layers = len(
ast.literal_eval("self.trainer.model." + self.layers_attribute) reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
) )
self.freeze_all_layers() self.freeze_all_layers()
self.active_layers_indices = [] self.active_layers_indices = []
def freeze_all_layers(self): def freeze_all_layers(self):
layers = ast.literal_eval( layers = reduce(
"self.trainer.model." + self.layers_attribute getattr, self.layers_attribute.split("."), self.trainer.model
) # Dynamically execute to get layers )
for layer in layers: for layer in layers:
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = False param.requires_grad = False
@@ -48,9 +57,9 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.freeze_all_layers() self.freeze_all_layers()
# Randomly select n_layers to activate # Randomly select n_layers to activate
layers = ast.literal_eval( layers = reduce(
"self.trainer.model" + self.layers_attribute getattr, self.layers_attribute.split("."), self.trainer.model
) # Re-fetch layer references )
self.active_layers_indices = np.random.choice( self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False range(self.total_layers), self.n_layers, replace=False
) )

View File

@@ -382,7 +382,7 @@ class LISAConfig(BaseModel):
metadata={"help": "how often to switch layers in LISA"}, metadata={"help": "how often to switch layers in LISA"},
) )
lisa_layers_attribute: Optional[str] = Field( lisa_layers_attribute: Optional[str] = Field(
default="", default="model.layers",
metadata={"help": "path under the model to access the layers"}, metadata={"help": "path under the model to access the layers"},
) )