chore: refactor register linear llama

This commit is contained in:
NanoCode012
2025-02-05 18:03:04 +07:00
parent 8294e6218f
commit 2fd5c45c2e
3 changed files with 29 additions and 21 deletions

View File

@@ -60,12 +60,6 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
model, config=linear_llama_config, train_attention=True
)
# register model
from transformers import AutoConfig, AutoModel
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# set save_path, save tokenizer and model config.
save_path = str(os.path.join(cfg.output_dir, "distilled"))
tokenizer.save_pretrained(save_path)

View File

@@ -21,6 +21,16 @@ class LinearizePlugin(BasePlugin):
Plugin for lolcats integration with Axolotl.
"""
def __init__(self):
    """
    Initialize the plugin and register the Linear Llama model with transformers.
    """
    super().__init__()

    # The import is done inside the method (not at module import time); the
    # helper performs the Auto* registrations for the Linear Llama classes.
    from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
        register_linear_llama as _register_linear_llama,
    )

    _register_linear_llama()
def get_input_args(self):
    """
    Return the dotted path string naming this plugin's input-args class.
    """
    args_path = "axolotl.integrations.lolcats.LinearAttentionArgs"
    return args_path

View File

@@ -94,7 +94,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
@classmethod
def from_llama(
cls,
model: LlamaModel | LlamaForCausalLM,
model: LlamaForCausalLM,
config: LinearLlamaConfig,
train_attention: bool = False,
remove_base_attn: bool = True,
@@ -103,30 +103,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
Initialize a LinearLlamaForCausalLM from a LlamaModel
"""
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
llama_model = model.model
else:
llama_model = model
if config is None:
raise ValueError("Missing config")
# initialize the model with prior weights
new_model = cls(config=config)
from axolotl.integrations.lolcats.linearize_attention import convert_attention
llama_model = convert_attention(
llama_model,
new_model.model = convert_attention(
model.model,
DictDefault(**config.attention_config),
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
# initialize the model with prior weights
new_model = cls(config=config)
del new_model.model # remove the default model
del new_model.lm_head # remove the default lm_head
new_model.model = llama_model
new_model.lm_head = model.lm_head
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model
@@ -147,3 +139,15 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
)
remove_base_attention(self.model)
def register_linear_llama():
    """
    Register Linear LLaMA model with the Transformers library.

    Associates the ``"linear_llama"`` model type with ``LinearLlamaConfig`` and
    maps that config to ``LinearLlamaModel`` (for ``AutoModel``) and
    ``LinearLlamaForCausalLM`` (for ``AutoModelForCausalLM``).

    Safe to call more than once: ``exist_ok=True`` is passed to every
    ``register`` call so that a repeated registration (e.g. the plugin being
    instantiated several times in one process) does not raise ``ValueError``.
    """
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    AutoConfig.register("linear_llama", LinearLlamaConfig, exist_ok=True)
    AutoModel.register(LinearLlamaConfig, LinearLlamaModel, exist_ok=True)
    AutoModelForCausalLM.register(
        LinearLlamaConfig, LinearLlamaForCausalLM, exist_ok=True
    )