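Fix the default of `lisa_layers_attribute` ("model.layers") and fix attribute traversal for layers (use `reduce(getattr, ...)` instead of `ast.literal_eval`)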
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
"""module for LISA"""
|
||||
import ast
|
||||
"""
|
||||
module for LISA
|
||||
|
||||
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
|
||||
Arxiv: https://arxiv.org/abs/2403.17919
|
||||
License: Apache 2.0
|
||||
"""
|
||||
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
@@ -22,16 +29,18 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||
self.layers_attribute = layers_attribute
|
||||
self.trainer = trainer
|
||||
|
||||
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
||||
|
||||
self.total_layers = len(
|
||||
ast.literal_eval("self.trainer.model." + self.layers_attribute)
|
||||
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
||||
)
|
||||
self.freeze_all_layers()
|
||||
self.active_layers_indices = []
|
||||
|
||||
def freeze_all_layers(self):
|
||||
layers = ast.literal_eval(
|
||||
"self.trainer.model." + self.layers_attribute
|
||||
) # Dynamically execute to get layers
|
||||
layers = reduce(
|
||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||
)
|
||||
for layer in layers:
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
@@ -48,9 +57,9 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||
self.freeze_all_layers()
|
||||
|
||||
# Randomly select n_layers to activate
|
||||
layers = ast.literal_eval(
|
||||
"self.trainer.model" + self.layers_attribute
|
||||
) # Re-fetch layer references
|
||||
layers = reduce(
|
||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||
)
|
||||
self.active_layers_indices = np.random.choice(
|
||||
range(self.total_layers), self.n_layers, replace=False
|
||||
)
|
||||
|
||||
@@ -382,7 +382,7 @@ class LISAConfig(BaseModel):
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = Field(
|
||||
default="",
|
||||
default="model.layers",
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user