refactor: move to modeling file and remove axolotl imports
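
In practice, the change means the modeling file now defines the conversion helpers itself instead of importing them from the axolotl integration package. A minimal before/after sketch of the import change (abbreviated from the diff below, not a complete listing):

    # before: the model class pulled the helpers in from the integration package
    from axolotl.integrations.lolcats.linearize_attention import convert_attention

    # after: convert_attention(), toggle_attention(), remove_base_attention() and the
    # related helpers shown below are defined directly in the modeling file
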
@@ -9,8 +9,12 @@
"""Linear LLaMA model implementation."""

import logging
from functools import partial
from typing import Any

from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
@@ -19,11 +23,11 @@ from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
)

from axolotl.utils.dict import DictDefault

from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig

LOG = logging.getLogger(__name__)


class LinearLlamaDecoderLayer(LlamaDecoderLayer):
    """
@@ -109,11 +113,10 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
        # initialize the model with prior weights
        new_model = cls(config=config)

        from axolotl.integrations.lolcats.linearize_attention import convert_attention

        del new_model.model  # remove the default model
        new_model.model = convert_attention(
            model.model,
            DictDefault(**config.attention_config),
            attention_config=config.attention_config,
            train_attention=train_attention,
            remove_base_attn=remove_base_attn,
        )
@@ -126,7 +129,6 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
        """
        Toggle attention to be trainable or not
        """
        from axolotl.integrations.lolcats.linearize_attention import toggle_attention

        toggle_attention(self.model, train=train)

@@ -134,13 +136,206 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
        """
        Remove base attention after distillation
        """
        from axolotl.integrations.lolcats.linearize_attention import (
            remove_base_attention,
        )

        remove_base_attention(self.model)


def convert_attention(
    model: nn.Module,
    attention_config: dict,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Call to convert all attention layers
    """
    # Get the layers to convert if provided
    softmax_attns = attention_config.get("softmax_attentions", [])

    # Get the attention to convert to
    attention_type = attention_config.get("attention_type")

    if attention_type != "softmax":
        layers = traverse_layers(model)
        for layer_idx, layer in enumerate(
            tqdm(layers, desc="Converting attentions...")
        ):
            if layer_idx not in softmax_attns:
                layer.self_attn = convert_llama_attention(
                    layer,
                    attention_config,
                    layers,
                    train_attention,
                    remove_base_attn,
                )
                layer.self_attn.converted = True
            else:
                # Freeze any preserved softmax attention layers
                for p in layer.parameters():
                    p.requires_grad = False
    else:
        LOG.info(
            f"-> attention_config.attention_type is {attention_type}; not converting attentions"
        )
    return model
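
A minimal sketch of how the new module-level convert_attention is driven. The config values below (attention type, which layers stay softmax) are illustrative only, not taken from a real recipe:

    # Illustrative: linearize every layer except layer 0, and keep the softmax
    # teacher on each layer so the new attentions can be trained against it.
    attention_config = {
        "attention_type": "lolcats_llama",
        "softmax_attentions": [0],
    }
    model = convert_attention(
        model,
        attention_config,
        train_attention=True,
        remove_base_attn=False,  # keep the teacher around for attention transfer
    )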


def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Make attentions trainable if train is True
    -> Set train_attention = False when finetuning
    """
    for layer in traverse_layers(llama_model):
        layer.self_attn.train_attention = train
    return llama_model


def remove_base_attention(llama_model: nn.Module):
    """
    Remove teacher attention after distillation (if we keep it)
    """
    for layer in traverse_layers(llama_model):
        if getattr(layer.self_attn, "base_attn", False):
            del layer.self_attn.base_attn
    return llama_model
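
These two helpers are the distill-then-finetune switch described in the docstrings. A hedged sketch of the intended sequence (stage names are descriptive, not API):

    # Stage 1: attention transfer, train only the new linear attentions.
    toggle_attention(model, train=True)

    # Stage 2: finetuning, stop training the attentions and drop the softmax
    # teacher weights kept on each layer as `base_attn`.
    toggle_attention(model, train=False)
    remove_base_attention(model)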


def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return list of model layers
    """
    try:
        layers = model.model.layers
        if verbose:
            LOG.info("-> Loading from model.model.layers")
    except AttributeError as e:  # if base model
        if verbose:
            LOG.info(e)
        try:
            layers = model.layers
            if verbose:
                LOG.info("-> Loading from model.layers")
        except AttributeError as e1:  # If we make a PEFT model
            if verbose:
                LOG.info(e1)
            layers = model.base_model.model.model.layers
            if verbose:
                LOG.info("-> Loading from model.base_model.model.model.layers")
    return layers
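
traverse_layers resolves the decoder stack through three possible wrappers. A small self-contained sketch with stand-in modules (the classes here are placeholders, not transformers classes):

    import torch.nn as nn

    class _Backbone(nn.Module):  # stands in for LlamaModel: exposes .layers
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Identity(), nn.Identity()])

    class _CausalLM(nn.Module):  # stands in for LlamaForCausalLM: exposes .model.layers
        def __init__(self):
            super().__init__()
            self.model = _Backbone()

    lm = _CausalLM()
    assert traverse_layers(lm) is lm.model.layers        # first branch
    assert traverse_layers(lm.model) is lm.model.layers  # AttributeError fallback to model.layers
    # A PEFT-wrapped model resolves via model.base_model.model.model.layers.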


def convert_llama_attention(
    layer: nn.Module,
    attention_config: dict,
    layers: list[nn.Module],  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Converts a single layer's attention layer as specified by attention_config
    """
    return get_attention(**attention_config)(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )


def get_attention(attention_type: str, **kwargs):
    """
    Get the linear attention class; either purely linear or linear with sliding window
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'
    """
    kwargs["attention_type"] = attention_type

    if attention_type == "lolcats_llama":
        from .attention import LolcatsLinearAttention

        return partial(LolcatsLinearAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_tk":
        from .attention import LolcatsTKWindowAttention

        return partial(LolcatsTKWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw":
        from .attention import LolcatsSlidingWindowAttention

        return partial(LolcatsSlidingWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw_linear":
        from .attention import LolcatsLinearSlidingWindowAttention

        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .attention import LolcatsTKWindowLongAttention

        return partial(LolcatsTKWindowLongAttention, **kwargs)

    elif attention_type == "lolcats_long_llama_window_sw":
        from .attention import LolcatsSlidingWindowLongAttention

        return partial(LolcatsSlidingWindowLongAttention, **kwargs)

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .attention import LolcatsWindowAttentionTKGen

        return partial(LolcatsWindowAttentionTKGen, **kwargs)

    else:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None
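
get_attention returns a partially-applied class rather than an instance; convert_llama_attention then finishes construction per layer. A hedged sketch of that two-step call, using only the keyword arguments that appear in the diff (anything else would depend on the attention classes' real signatures):

    factory = get_attention("lolcats_llama")  # partial(LolcatsLinearAttention, attention_type="lolcats_llama")
    layer.self_attn = factory(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,
        max_layer_idx=len(layers) - 1,
        train_attention=True,
        remove_base_attn=False,
    )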


def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating
    """
    if attention_type is None:
        return past_key_values

    # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
    elif "lolcats_llama_window_tk_gen" in attention_type:
        from .attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    elif "llama_window_tk" in attention_type:
        from .attention import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    elif "llama_window_sw" in attention_type:
        from .attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    elif "llama_window_sw_linear" in attention_type:
        from .attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    elif "softmax" in attention_type:
        return past_key_values

    else:
        from .attention import LinearAttentionState

        return LinearAttentionState()
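
The cache is selected by substring matching, so branch order matters: "lolcats_llama_window_sw_linear" already contains "llama_window_sw" and hits the sliding-window branch before the "_sw_linear" one, and the later equality check for "lolcats_llama_window_tk_gen" is shadowed by the substring check at the top. A short sketch of the dispatch (annotated results follow the logic above):

    past_key_values = get_attention_cache(None, past_key_values)  # pass the HF cache through unchanged
    get_attention_cache("lolcats_llama")                   # -> LinearAttentionState()
    get_attention_cache("lolcats_llama_window_sw")         # -> LinearAttentionSlidingWindowCache()
    get_attention_cache("lolcats_llama_window_sw_linear")  # -> LinearAttentionSlidingWindowCache() ("llama_window_sw" matches first)
    get_attention_cache("lolcats_llama_window_tk_gen")     # -> LinearAttentionTKWindowGenerationCache()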


def register_linear_llama():
    """
    Register Linear LLaMA model with the Transformers library.

@@ -1,218 +0,0 @@
"""
Convert attention to linear attention

Adapted from: https://github.com/HazyResearch/lolcats/blob/main/src/model/convert_model.py

@misc{zhang2024lolcatslowranklinearizinglarge,
      title={LoLCATs: On Low-Rank Linearizing of Large Language Models},
      author={Michael Zhang and Simran Arora and Rahul Chalamala and Alan Wu and Benjamin Spector and Aaryan Singhal and Krithik Ramesh and Christopher Ré},
      year={2024},
      eprint={2410.10254},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.10254},
}
"""

import logging
from functools import partial
from typing import Any

import torch.nn as nn
from tqdm import tqdm

from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.integrations.lolcats.linearize_attention")


def convert_attention(
    model: nn.Module,
    attention_config: DictDefault,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Call to convert all attention layers
    """
    softmax_attns = []
    if "softmax_attentions" in attention_config:
        softmax_attns = attention_config["softmax_attentions"]
    if attention_config.attention_type != "softmax":
        layers = traverse_layers(model)
        for layer_idx, layer in enumerate(
            tqdm(layers, desc="Converting attentions...")
        ):
            if layer_idx not in softmax_attns:
                layer.self_attn = convert_llama_attention(
                    layer,
                    attention_config,
                    layers,
                    train_attention,
                    remove_base_attn,
                )
                layer.self_attn.converted = True
            else:  # Freeze any preserved softmax attention layers
                for p in layer.parameters():
                    p.requires_grad = False
    else:
        LOG.info(
            f"-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions"
        )
    return model
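
This deleted version typed attention_config as axolotl's DictDefault, which is why it can mix attribute access (attention_config.attention_type) with key access, while the replacement earlier in the diff switches to a plain dict and .get(). A tiny illustration, assuming DictDefault behaves like an attribute-access dict (which is how it is used here):

    cfg = DictDefault({"attention_type": "lolcats_llama", "softmax_attentions": []})
    assert cfg.attention_type == cfg["attention_type"] == "lolcats_llama"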


def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Make attentions trainable if train is True
    -> Set train_attention = False when finetuning
    """
    for layer in traverse_layers(llama_model):
        layer.self_attn.train_attention = train
    return llama_model


def remove_base_attention(llama_model: nn.Module):
    """
    Remove teacher attention after distillation (if we keep it)
    """
    for layer in traverse_layers(llama_model):
        if getattr(layer.self_attn, "base_attn", False):
            del layer.self_attn.base_attn
    return llama_model


def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return list of model layers
    """
    try:
        layers = model.model.layers
        if verbose:
            LOG.info("-> Loading from model.model.layers")
    except AttributeError as e:  # if base model
        if verbose:
            LOG.info(e)
        try:
            layers = model.layers
            if verbose:
                LOG.info("-> Loading from model.layers")
        except AttributeError as e1:  # If we make a PEFT model
            if verbose:
                LOG.info(e1)
            layers = model.base_model.model.model.layers
            if verbose:
                LOG.info("-> Loading from model.base_model.model.model.layers")
    return layers


def convert_llama_attention(
    layer: nn.Module,
    attention_config: DictDefault,
    layers: list[nn.Module],  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Converts a single layer's attention layer as specified by attention_config
    """
    return get_attention(**attention_config)(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )


def get_attention(attention_type: str, **kwargs):
    """
    Get the linear attention class; either purely linear or linear with sliding window
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'
    """
    kwargs["attention_type"] = attention_type

    if attention_type == "lolcats_llama":
        from .linear_llama.attention import LolcatsLinearAttention

        return partial(LolcatsLinearAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_tk":
        from .linear_llama.attention import LolcatsTKWindowAttention

        return partial(LolcatsTKWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw":
        from .linear_llama.attention import LolcatsSlidingWindowAttention

        return partial(LolcatsSlidingWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw_linear":
        from .linear_llama.attention import LolcatsLinearSlidingWindowAttention

        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .linear_llama.attention import LolcatsTKWindowLongAttention

        return partial(LolcatsTKWindowLongAttention, **kwargs)

    elif attention_type == "lolcats_long_llama_window_sw":
        from .linear_llama.attention import LolcatsSlidingWindowLongAttention

        return partial(LolcatsSlidingWindowLongAttention, **kwargs)

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .linear_llama.attention import LolcatsWindowAttentionTKGen

        return partial(LolcatsWindowAttentionTKGen, **kwargs)

    else:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None


def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating
    """
    if attention_type is None:
        return past_key_values

    # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
    elif "lolcats_llama_window_tk_gen" in attention_type:
        from .linear_llama.attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    elif "llama_window_tk" in attention_type:
        from .linear_llama.attention import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    elif "llama_window_sw" in attention_type:
        from .linear_llama.attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    elif "llama_window_sw_linear" in attention_type:
        from .linear_llama.attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .linear_llama.attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    elif "softmax" in attention_type:
        return past_key_values

    else:
        from .linear_llama.attention import LinearAttentionState

        return LinearAttentionState()