From 3a9ad7c66e99ce4f68ca59ae559bb252cbf5ed97 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 30 Mar 2024 22:55:15 -0400
Subject: [PATCH] add lisa support

---
 src/axolotl/core/trainer_builder.py            | 24 ++++++
 src/axolotl/utils/callbacks/lisa.py            | 73 +++++++++++++++++++
 .../config/models/input/v0_4_1/__init__.py     | 18 +++++
 3 files changed, 115 insertions(+)
 create mode 100644 src/axolotl/utils/callbacks/lisa.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4d85b40de..cc7275184 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -45,6 +45,7 @@ from axolotl.utils.callbacks import (
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
+from axolotl.utils.callbacks.lisa import lisa_callback_factory
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
     orpo_alpha: Optional[float] = field(
         default=None,
     )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )
 
 
 class AxolotlTrainer(Trainer):
@@ -938,6 +951,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
             callbacks.append(early_stop_cb)
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            callbacks.append(lisa_callback_factory(trainer))
         return callbacks
 
     def _get_trainer_cls(self):
@@ -1229,6 +1244,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "relora_prune_ratio"
             ] = self.cfg.relora_prune_ratio
 
+        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
+            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
+            training_arguments_kwargs[
+                "lisa_step_interval"
+            ] = self.cfg.lisa_step_interval
+            training_arguments_kwargs[
+                "lisa_layers_attribute"
+            ] = self.cfg.lisa_layers_attribute
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py
new file mode 100644
index 000000000..a42f48bb8
--- /dev/null
+++ b/src/axolotl/utils/callbacks/lisa.py
@@ -0,0 +1,73 @@
+"""module for LISA (Layerwise Importance Sampled AdamW) layer switching"""
+from operator import attrgetter
+from typing import TYPE_CHECKING
+
+import numpy as np
+from transformers import TrainerCallback
+
+if TYPE_CHECKING:
+    from axolotl.core.trainer_builder import AxolotlTrainer
+
+
+def lisa_callback_factory(trainer: "AxolotlTrainer"):
+    class LISACallback(TrainerCallback):
+        """trainer callback for lisa layer switching"""
+
+        def __init__(
+            self, n_layers, step_interval, trainer, layers_attribute="model.layers"
+        ):
+            super().__init__()
+            self.n_layers = n_layers
+            self.step_interval = step_interval
+            self.layers_attribute = layers_attribute
+            self.trainer = trainer
+
+            self.total_layers = len(
+                attrgetter(self.layers_attribute)(self.trainer.model)
+            )
+            self.freeze_all_layers()
+            self.active_layers_indices = []
+
+        def freeze_all_layers(self):
+            layers = attrgetter(self.layers_attribute)(
+                self.trainer.model
+            )  # resolve the dotted attribute path to the layer list
+            for layer in layers:
+                for param in layer.parameters():
+                    param.requires_grad = False
+
+        def on_step_begin(
+            self, args, state, control, **kwargs
+        ):  # pylint: disable=unused-argument
+            # Check if it's time to switch active layers, including at step 0
+            if state.global_step % self.step_interval == 0 or state.global_step == 1:
+                self.switch_active_layers()
+
+        def switch_active_layers(self):
+            # First, disable gradients for all layers
+            self.freeze_all_layers()
+
+            # Randomly select n_layers to activate
+            layers = attrgetter(self.layers_attribute)(
+                self.trainer.model
+            )  # Re-fetch layer references
+            self.active_layers_indices = np.random.choice(
+                range(self.total_layers), self.n_layers, replace=False
+            )
+            print(
+                f"Activating layers at indices: {self.active_layers_indices} for the next steps."
+            )
+
+            # Enable gradients only for the selected layers
+            for idx in self.active_layers_indices:
+                for param in layers[idx].parameters():
+                    param.requires_grad = True
+
+    lisa_callback = LISACallback(
+        n_layers=trainer.args.lisa_n_layers,
+        step_interval=trainer.args.lisa_step_interval,
+        trainer=trainer,
+        layers_attribute=trainer.args.lisa_layers_attribute,
+    )
+
+    return lisa_callback
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index c07c0ff75..c66ae70d4 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
     hf_mlflow_log_artifacts: Optional[bool] = None
 
 
+class LISAConfig(BaseModel):
+    """LISA options"""
+
+    lisa_n_layers: Optional[int] = Field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = Field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = Field(
+        default="model.layers",
+        metadata={"help": "path under the model to access the layers"},
+    )
+
+
 class WandbConfig(BaseModel):
     """wandb configuration subset"""
 
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
+    LISAConfig,
     RemappedParameters,
     DeprecatedParameters,
     BaseModel,
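
Usage note (illustrative, not part of the patch): with the three fields this
change wires through LISAConfig, a minimal axolotl YAML snippet enabling LISA
might look like the following; the values are example choices, not defaults
introduced by this patch:

    lisa_n_layers: 4
    lisa_step_interval: 20
    lisa_layers_attribute: model.layers

With a snippet like this, every lisa_step_interval optimizer steps the callback
freezes all layers under model.layers and unfreezes lisa_n_layers randomly
chosen ones, so only those layers (plus any parameters outside the
layers_attribute path, which the callback never touches) receive gradients.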