chore: refactor register linear llama

This commit is contained in:
NanoCode012
2025-02-05 18:03:04 +07:00
parent 8294e6218f
commit 2fd5c45c2e
3 changed files with 29 additions and 21 deletions

View File

@@ -60,12 +60,6 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
model, config=linear_llama_config, train_attention=True model, config=linear_llama_config, train_attention=True
) )
# register model
from transformers import AutoConfig, AutoModel
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# set save_path, save tokenizer and model config. # set save_path, save tokenizer and model config.
save_path = str(os.path.join(cfg.output_dir, "distilled")) save_path = str(os.path.join(cfg.output_dir, "distilled"))
tokenizer.save_pretrained(save_path) tokenizer.save_pretrained(save_path)

View File

@@ -21,6 +21,16 @@ class LinearizePlugin(BasePlugin):
Plugin for lolcats integration with Axolotl. Plugin for lolcats integration with Axolotl.
""" """
def __init__(self):
super().__init__()
# Register the Linear Llama model with transformers
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
register_linear_llama,
)
register_linear_llama()
def get_input_args(self): def get_input_args(self):
return "axolotl.integrations.lolcats.LinearAttentionArgs" return "axolotl.integrations.lolcats.LinearAttentionArgs"

View File

@@ -94,7 +94,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
@classmethod @classmethod
def from_llama( def from_llama(
cls, cls,
model: LlamaModel | LlamaForCausalLM, model: LlamaForCausalLM,
config: LinearLlamaConfig, config: LinearLlamaConfig,
train_attention: bool = False, train_attention: bool = False,
remove_base_attn: bool = True, remove_base_attn: bool = True,
@@ -103,30 +103,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
Initialize a LinearLlamaForCausalLM from a LlamaModel Initialize a LinearLlamaForCausalLM from a LlamaModel
""" """
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
llama_model = model.model
else:
llama_model = model
if config is None: if config is None:
raise ValueError("Missing config") raise ValueError("Missing config")
# initialize the model with prior weights
new_model = cls(config=config)
from axolotl.integrations.lolcats.linearize_attention import convert_attention from axolotl.integrations.lolcats.linearize_attention import convert_attention
llama_model = convert_attention( new_model.model = convert_attention(
llama_model, model.model,
DictDefault(**config.attention_config), DictDefault(**config.attention_config),
train_attention=train_attention, train_attention=train_attention,
remove_base_attn=remove_base_attn, remove_base_attn=remove_base_attn,
) )
# initialize the model with prior weights new_model.lm_head.load_state_dict(model.lm_head.state_dict())
new_model = cls(config=config)
del new_model.model # remove the default model
del new_model.lm_head # remove the default lm_head
new_model.model = llama_model
new_model.lm_head = model.lm_head
return new_model return new_model
@@ -147,3 +139,15 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
) )
remove_base_attention(self.model) remove_base_attention(self.model)
def register_linear_llama():
"""
Register Linear LLaMA model with the Transformers library.
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)