diff --git a/src/axolotl/cli/convert_linear_attention.py b/src/axolotl/cli/convert_linear_attention.py index 061e3fe8b..dee2de05c 100644 --- a/src/axolotl/cli/convert_linear_attention.py +++ b/src/axolotl/cli/convert_linear_attention.py @@ -16,11 +16,14 @@ from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer from axolotl.common.datasets import load_datasets from axolotl.integrations.base import PluginManager -from axolotl.integrations.lolcats.linearize_attention import ( - remove_base_attention, - toggle_attention, +from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import ( + LinearLlamaConfig, +) +from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import ( + LinearLlamaForCausalLM, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model_config from axolotl.utils.trainer import setup_trainer LOG = logging.getLogger(__name__) @@ -42,11 +45,15 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: # load model model, tokenizer = load_model_and_tokenizer(cfg=cfg) - # convert attention - from axolotl.integrations.lolcats.linearize_attention import convert_attention + # load config + base_config = load_model_config(cfg) - model = convert_attention( - model, cfg.attention_config, train_attention=True, remove_base_attn=True + # convert to linear llama + linear_llama_config = LinearLlamaConfig.from_llama( + base_config, cfg.attention_config + ) + model = LinearLlamaForCausalLM.from_llama( + model, config=linear_llama_config, train_attention=True ) # Get datasets @@ -56,7 +63,7 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: total_num_steps = dataset_meta.total_num_steps # toggle attention to be trainable - model = toggle_attention(model, train=True) + model.toggle_attention(train=True) # Setup trainer trainer = setup_trainer( @@ -73,8 +80,8 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) # drop base_attention + remove training attn - model = toggle_attention(model, train=False) - model = remove_base_attention(model) + model.toggle_attention(train=False) + model.remove_base_attention() # NOTE: If in peft mode, consider whether to auto-merge diff --git a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py index 31e81a274..2eb0ec8a4 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py @@ -62,3 +62,18 @@ class LinearLlamaConfig(LlamaConfig): # Set default attention config if none provided self.attention_config = attention_config + + @classmethod + def from_llama(cls, llama_config: LlamaConfig, attention_config: dict): + """ + Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config. + + Args: + llama_config (:class:`~transformers.LlamaConfig`): + The LlamaConfig to inherit from. + + attention_config (`dict`): + Dictionary containing the configuration for linear attention mechanism. + """ + + return cls(attention_config=attention_config, **llama_config.to_dict()) diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py index 937735e0d..57ea6cacb 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py @@ -75,6 +75,10 @@ class LinearLlamaModel(LlamaModel): class LinearLlamaForCausalLM(LlamaForCausalLM): + """ + Linear LLaMA model for causal language modeling. + """ + def __init__(self, config): super().__init__(config) self.model = LinearLlamaModel(config) @@ -113,3 +117,21 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): ) return new_model + + def toggle_attention(self, train: bool = True): + """ + Toggle attention to be trainable or not + """ + from axolotl.integrations.lolcats.linearize_attention import toggle_attention + + toggle_attention(self.model, train=train) + + def remove_base_attention(self): + """ + Remove base attention after distillation + """ + from axolotl.integrations.lolcats.linearize_attention import ( + remove_base_attention, + ) + + remove_base_attention(self.model)