diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index a42f48bb8..4df3225bb 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -1,5 +1,12 @@ -"""module for LISA""" -import ast +""" +module for LISA + +Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl +Arxiv: https://arxiv.org/abs/2403.17919 +License: Apache 2.0 +""" + +from functools import reduce from typing import TYPE_CHECKING import numpy as np @@ -22,16 +29,18 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): self.layers_attribute = layers_attribute self.trainer = trainer + reduce(getattr, self.layers_attribute.split("."), self.trainer.model) + self.total_layers = len( - ast.literal_eval("self.trainer.model." + self.layers_attribute) + reduce(getattr, self.layers_attribute.split("."), self.trainer.model) ) self.freeze_all_layers() self.active_layers_indices = [] def freeze_all_layers(self): - layers = ast.literal_eval( - "self.trainer.model." + self.layers_attribute - ) # Dynamically execute to get layers + layers = reduce( + getattr, self.layers_attribute.split("."), self.trainer.model + ) for layer in layers: for param in layer.parameters(): param.requires_grad = False @@ -48,9 +57,9 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"): self.freeze_all_layers() # Randomly select n_layers to activate - layers = ast.literal_eval( - "self.trainer.model" + self.layers_attribute - ) # Re-fetch layer references + layers = reduce( + getattr, self.layers_attribute.split("."), self.trainer.model + ) self.active_layers_indices = np.random.choice( range(self.total_layers), self.n_layers, replace=False ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c66ae70d4..5a927602f 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -382,7 +382,7 @@ class LISAConfig(BaseModel): metadata={"help": "how often to switch layers in LISA"}, ) lisa_layers_attribute: Optional[str] = Field( - default="", + default="model.layers", metadata={"help": "path under the model to access the layers"}, )