chore: refactor register linear llama
This commit is contained in:
@@ -60,12 +60,6 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
model, config=linear_llama_config, train_attention=True
|
||||
)
|
||||
|
||||
# register model
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||
|
||||
# set save_path, save tokenizer and model config.
|
||||
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
||||
tokenizer.save_pretrained(save_path)
|
||||
|
||||
@@ -21,6 +21,16 @@ class LinearizePlugin(BasePlugin):
|
||||
Plugin for lolcats integration with Axolotl.
|
||||
"""
|
||||
|
||||
def __init__(self):
    """
    Set up the plugin and register the Linear Llama model with transformers.

    Runs the base-class initializer first, then calls
    ``register_linear_llama`` from the lolcats ``modeling_linear_llama``
    module.
    """
    super().__init__()

    # Function-scope import, aliased locally; the call below performs the
    # registration with transformers.
    from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
        register_linear_llama as _register_linear_llama,
    )

    _register_linear_llama()
|
||||
|
||||
def get_input_args(self):
    """
    Return the string identifying this plugin's input-args class.

    The value looks like a dotted import path; presumably the plugin
    framework resolves it to ``LinearAttentionArgs`` (confirm against the
    ``BasePlugin`` contract).
    """
    input_args_path = "axolotl.integrations.lolcats.LinearAttentionArgs"
    return input_args_path
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls,
|
||||
model: LlamaModel | LlamaForCausalLM,
|
||||
model: LlamaForCausalLM,
|
||||
config: LinearLlamaConfig,
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
@@ -103,30 +103,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
Initialize a LinearLlamaForCausalLM from a LlamaModel
|
||||
"""
|
||||
|
||||
# Handle LlamaForCausalLM
|
||||
if isinstance(model, LlamaForCausalLM):
|
||||
llama_model = model.model
|
||||
else:
|
||||
llama_model = model
|
||||
|
||||
if config is None:
|
||||
raise ValueError("Missing config")
|
||||
|
||||
# initialize the model with prior weights
|
||||
new_model = cls(config=config)
|
||||
|
||||
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
||||
|
||||
llama_model = convert_attention(
|
||||
llama_model,
|
||||
new_model.model = convert_attention(
|
||||
model.model,
|
||||
DictDefault(**config.attention_config),
|
||||
train_attention=train_attention,
|
||||
remove_base_attn=remove_base_attn,
|
||||
)
|
||||
|
||||
# initialize the model with prior weights
|
||||
new_model = cls(config=config)
|
||||
del new_model.model # remove the default model
|
||||
del new_model.lm_head # remove the default lm_head
|
||||
new_model.model = llama_model
|
||||
new_model.lm_head = model.lm_head
|
||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
||||
|
||||
return new_model
|
||||
|
||||
@@ -147,3 +139,15 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
)
|
||||
|
||||
remove_base_attention(self.model)
|
||||
|
||||
|
||||
def register_linear_llama():
    """
    Register the Linear LLaMA classes with the Transformers Auto* factories.

    Registers ``LinearLlamaConfig`` under the ``"linear_llama"`` model type,
    then associates that config class with ``LinearLlamaModel`` on
    ``AutoModel`` and with ``LinearLlamaForCausalLM`` on
    ``AutoModelForCausalLM``.
    """
    # Function-scope import, kept from the original implementation.
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    AutoConfig.register("linear_llama", LinearLlamaConfig)

    # Same registration order as before: base model first, then causal LM.
    auto_to_model = (
        (AutoModel, LinearLlamaModel),
        (AutoModelForCausalLM, LinearLlamaForCausalLM),
    )
    for auto_factory, model_cls in auto_to_model:
        auto_factory.register(LinearLlamaConfig, model_cls)
|
||||
|
||||
Reference in New Issue
Block a user