Fix the default for `lisa_layers_attribute` and fix attribute traversal for layers (use `reduce(getattr, ...)` instead of `ast.literal_eval`)

This commit is contained in:
Wing Lian
2024-03-31 00:27:04 -04:00
parent 3a9ad7c66e
commit 21a5094226
2 changed files with 19 additions and 10 deletions

View File

@@ -1,5 +1,12 @@
"""module for LISA"""
import ast
"""
module for LISA
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""
from functools import reduce
from typing import TYPE_CHECKING
import numpy as np
@@ -22,16 +29,18 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.layers_attribute = layers_attribute
self.trainer = trainer
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
self.total_layers = len(
ast.literal_eval("self.trainer.model." + self.layers_attribute)
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
)
self.freeze_all_layers()
self.active_layers_indices = []
def freeze_all_layers(self):
layers = ast.literal_eval(
"self.trainer.model." + self.layers_attribute
) # Dynamically execute to get layers
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
for layer in layers:
for param in layer.parameters():
param.requires_grad = False
@@ -48,9 +57,9 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
self.freeze_all_layers()
# Randomly select n_layers to activate
layers = ast.literal_eval(
"self.trainer.model" + self.layers_attribute
) # Re-fetch layer references
layers = reduce(
getattr, self.layers_attribute.split("."), self.trainer.model
)
self.active_layers_indices = np.random.choice(
range(self.total_layers), self.n_layers, replace=False
)

View File

@@ -382,7 +382,7 @@ class LISAConfig(BaseModel):
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = Field(
default="",
default="model.layers",
metadata={"help": "path under the model to access the layers"},
)