From 2fd5c45c2e60c87f9bd9ef32d025b6ceb4091e39 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 5 Feb 2025 18:03:04 +0700 Subject: [PATCH] chore: refactor register linear llama --- src/axolotl/cli/convert_linear_attention.py | 6 ---- src/axolotl/integrations/lolcats/__init__.py | 10 ++++++ .../linear_llama/modeling_linear_llama.py | 34 +++++++++++-------- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/axolotl/cli/convert_linear_attention.py b/src/axolotl/cli/convert_linear_attention.py index cacde2b56..dae7d7bb0 100644 --- a/src/axolotl/cli/convert_linear_attention.py +++ b/src/axolotl/cli/convert_linear_attention.py @@ -60,12 +60,6 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: model, config=linear_llama_config, train_attention=True ) - # register model - from transformers import AutoConfig, AutoModel - - AutoConfig.register("linear_llama", LinearLlamaConfig) - AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM) - # set save_path, save tokenizer and model config. save_path = str(os.path.join(cfg.output_dir, "distilled")) tokenizer.save_pretrained(save_path) diff --git a/src/axolotl/integrations/lolcats/__init__.py b/src/axolotl/integrations/lolcats/__init__.py index d9d916c0f..0eb6eee7e 100644 --- a/src/axolotl/integrations/lolcats/__init__.py +++ b/src/axolotl/integrations/lolcats/__init__.py @@ -21,6 +21,16 @@ class LinearizePlugin(BasePlugin): Plugin for lolcats integration with Axolotl. """ + def __init__(self): + super().__init__() + + # Register the Linear Llama model with transformers + from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import ( + register_linear_llama, + ) + + register_linear_llama() + def get_input_args(self): return "axolotl.integrations.lolcats.LinearAttentionArgs" diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py index 9aeba81c5..314584051 100644 --- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py +++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py @@ -94,7 +94,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): @classmethod def from_llama( cls, - model: LlamaModel | LlamaForCausalLM, + model: LlamaForCausalLM, config: LinearLlamaConfig, train_attention: bool = False, remove_base_attn: bool = True, @@ -103,30 +103,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): Initialize a LinearLlamaForCausalLM from a LlamaModel """ - # Handle LlamaForCausalLM - if isinstance(model, LlamaForCausalLM): - llama_model = model.model - else: - llama_model = model - if config is None: raise ValueError("Missing config") + # initialize the model with prior weights + new_model = cls(config=config) + from axolotl.integrations.lolcats.linearize_attention import convert_attention - llama_model = convert_attention( - llama_model, + new_model.model = convert_attention( + model.model, DictDefault(**config.attention_config), train_attention=train_attention, remove_base_attn=remove_base_attn, ) - # initialize the model with prior weights - new_model = cls(config=config) - del new_model.model # remove the default model - del new_model.lm_head # remove the default lm_head - new_model.model = llama_model - new_model.lm_head = model.lm_head + new_model.lm_head.load_state_dict(model.lm_head.state_dict()) return new_model @@ -147,3 +139,15 @@ class LinearLlamaForCausalLM(LlamaForCausalLM): ) remove_base_attention(self.model) + + +def register_linear_llama(): + """ + Register Linear LLaMA model with the Transformers library. + """ + + from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + + AutoConfig.register("linear_llama", LinearLlamaConfig) + AutoModel.register(LinearLlamaConfig, LinearLlamaModel) + AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)