chore: refactor register linear llama
This commit is contained in:
@@ -60,12 +60,6 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
model, config=linear_llama_config, train_attention=True
|
model, config=linear_llama_config, train_attention=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# register model
|
|
||||||
from transformers import AutoConfig, AutoModel
|
|
||||||
|
|
||||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
|
||||||
AutoModel.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
|
||||||
|
|
||||||
# set save_path, save tokenizer and model config.
|
# set save_path, save tokenizer and model config.
|
||||||
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
||||||
tokenizer.save_pretrained(save_path)
|
tokenizer.save_pretrained(save_path)
|
||||||
|
|||||||
@@ -21,6 +21,16 @@ class LinearizePlugin(BasePlugin):
|
|||||||
Plugin for lolcats integration with Axolotl.
|
Plugin for lolcats integration with Axolotl.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Register the Linear Llama model with transformers
|
||||||
|
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||||
|
register_linear_llama,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_linear_llama()
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.lolcats.LinearAttentionArgs"
|
return "axolotl.integrations.lolcats.LinearAttentionArgs"
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(
|
def from_llama(
|
||||||
cls,
|
cls,
|
||||||
model: LlamaModel | LlamaForCausalLM,
|
model: LlamaForCausalLM,
|
||||||
config: LinearLlamaConfig,
|
config: LinearLlamaConfig,
|
||||||
train_attention: bool = False,
|
train_attention: bool = False,
|
||||||
remove_base_attn: bool = True,
|
remove_base_attn: bool = True,
|
||||||
@@ -103,30 +103,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
Initialize a LinearLlamaForCausalLM from a LlamaModel
|
Initialize a LinearLlamaForCausalLM from a LlamaModel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Handle LlamaForCausalLM
|
|
||||||
if isinstance(model, LlamaForCausalLM):
|
|
||||||
llama_model = model.model
|
|
||||||
else:
|
|
||||||
llama_model = model
|
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("Missing config")
|
raise ValueError("Missing config")
|
||||||
|
|
||||||
|
# initialize the model with prior weights
|
||||||
|
new_model = cls(config=config)
|
||||||
|
|
||||||
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
from axolotl.integrations.lolcats.linearize_attention import convert_attention
|
||||||
|
|
||||||
llama_model = convert_attention(
|
new_model.model = convert_attention(
|
||||||
llama_model,
|
model.model,
|
||||||
DictDefault(**config.attention_config),
|
DictDefault(**config.attention_config),
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
# initialize the model with prior weights
|
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
||||||
new_model = cls(config=config)
|
|
||||||
del new_model.model # remove the default model
|
|
||||||
del new_model.lm_head # remove the default lm_head
|
|
||||||
new_model.model = llama_model
|
|
||||||
new_model.lm_head = model.lm_head
|
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
@@ -147,3 +139,15 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
remove_base_attention(self.model)
|
remove_base_attention(self.model)
|
||||||
|
|
||||||
|
|
||||||
|
def register_linear_llama():
|
||||||
|
"""
|
||||||
|
Register Linear LLaMA model with the Transformers library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
|
|
||||||
|
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||||
|
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
||||||
|
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||||
|
|||||||
Reference in New Issue
Block a user