Fix the default value of lisa_layers_attribute and fix attribute traversal for LISA layers (use reduce(getattr, ...) instead of ast.literal_eval)
This commit is contained in:
@@ -1,5 +1,12 @@
|
|||||||
"""module for LISA"""
|
"""
|
||||||
import ast
|
module for LISA
|
||||||
|
|
||||||
|
Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
|
||||||
|
Arxiv: https://arxiv.org/abs/2403.17919
|
||||||
|
License: Apache 2.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import reduce
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -22,16 +29,18 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
self.layers_attribute = layers_attribute
|
self.layers_attribute = layers_attribute
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
|
|
||||||
|
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
||||||
|
|
||||||
self.total_layers = len(
|
self.total_layers = len(
|
||||||
ast.literal_eval("self.trainer.model." + self.layers_attribute)
|
reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
|
||||||
)
|
)
|
||||||
self.freeze_all_layers()
|
self.freeze_all_layers()
|
||||||
self.active_layers_indices = []
|
self.active_layers_indices = []
|
||||||
|
|
||||||
def freeze_all_layers(self):
|
def freeze_all_layers(self):
|
||||||
layers = ast.literal_eval(
|
layers = reduce(
|
||||||
"self.trainer.model." + self.layers_attribute
|
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||||
) # Dynamically execute to get layers
|
)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
@@ -48,9 +57,9 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
|||||||
self.freeze_all_layers()
|
self.freeze_all_layers()
|
||||||
|
|
||||||
# Randomly select n_layers to activate
|
# Randomly select n_layers to activate
|
||||||
layers = ast.literal_eval(
|
layers = reduce(
|
||||||
"self.trainer.model" + self.layers_attribute
|
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||||
) # Re-fetch layer references
|
)
|
||||||
self.active_layers_indices = np.random.choice(
|
self.active_layers_indices = np.random.choice(
|
||||||
range(self.total_layers), self.n_layers, replace=False
|
range(self.total_layers), self.n_layers, replace=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -382,7 +382,7 @@ class LISAConfig(BaseModel):
|
|||||||
metadata={"help": "how often to switch layers in LISA"},
|
metadata={"help": "how often to switch layers in LISA"},
|
||||||
)
|
)
|
||||||
lisa_layers_attribute: Optional[str] = Field(
|
lisa_layers_attribute: Optional[str] = Field(
|
||||||
default="",
|
default="model.layers",
|
||||||
metadata={"help": "path under the model to access the layers"},
|
metadata={"help": "path under the model to access the layers"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user