refactor: move to modeling file and remove axolotl imports

NanoCode012
2025-02-05 18:16:39 +07:00
parent 2fd5c45c2e
commit 2d5f692fc0
2 changed files with 204 additions and 227 deletions

View File

@@ -9,8 +9,12 @@
"""Linear LLaMA model implementation."""
import logging
from functools import partial
from typing import Any
from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
@@ -19,11 +23,11 @@ from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
)
from axolotl.utils.dict import DictDefault
from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
"""
@@ -109,11 +113,10 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
# initialize the model with prior weights
new_model = cls(config=config)
from axolotl.integrations.lolcats.linearize_attention import convert_attention
del new_model.model # remove the default model
new_model.model = convert_attention(
model.model,
DictDefault(**config.attention_config),
attention_config=config.attention_config,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
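        # attention_config is now a plain dict rather than an axolotl DictDefault;
        # illustrative shape (key names from convert_attention below, values are
        # examples only and real configs may carry additional keys):
        #   {"attention_type": "lolcats_llama", "softmax_attentions": []}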
@@ -126,7 +129,6 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Toggle attention to be trainable or not
"""
from axolotl.integrations.lolcats.linearize_attention import toggle_attention
toggle_attention(self.model, train=train)
@@ -134,13 +136,206 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Remove base attention after distillation
"""
from axolotl.integrations.lolcats.linearize_attention import (
remove_base_attention,
)
remove_base_attention(self.model)
def convert_attention(
model: nn.Module,
attention_config: dict,
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Call to convert all attention layers
"""
    # Layer indices to keep as standard softmax attention (skipped during conversion)
softmax_attns = attention_config.get("softmax_attentions", [])
# Get the attention to convert to
attention_type = attention_config.get("attention_type")
if attention_type != "softmax":
layers = traverse_layers(model)
for layer_idx, layer in enumerate(
tqdm(layers, desc="Converting attentions...")
):
if layer_idx not in softmax_attns:
layer.self_attn = convert_llama_attention(
layer,
attention_config,
layers,
train_attention,
remove_base_attn,
)
layer.self_attn.converted = True
else:
# Freeze any preserved softmax attention layers
for p in layer.parameters():
p.requires_grad = False
else:
        LOG.info(
            f"-> attention_config['attention_type'] is {attention_type}; not converting attentions"
        )
return model
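# Usage sketch (illustrative; real attention_config dicts typically carry additional
# keys required by the attention classes in .attention):
#
#   from transformers import LlamaForCausalLM
#   llama = LlamaForCausalLM.from_pretrained("path/to/llama")  # hypothetical path
#   llama.model = convert_attention(
#       llama.model,
#       {"attention_type": "lolcats_llama", "softmax_attentions": [0]},
#       train_attention=True,
#   )
#
# Layer 0 keeps its frozen softmax attention here; every other layer's self_attn is
# replaced via convert_llama_attention below.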
def toggle_attention(llama_model: nn.Module, train: bool = False):
"""
Make attentions trainable if train is True
-> Set train_attention = False when finetuning
"""
for layer in traverse_layers(llama_model):
layer.self_attn.train_attention = train
return llama_model
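# Workflow sketch (illustrative): enable attention training for the distillation
# stage, then disable it again for the subsequent finetuning stage.
#
#   toggle_attention(llama_model, train=True)    # attention-distillation stage
#   toggle_attention(llama_model, train=False)   # finetuning stage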
def remove_base_attention(llama_model: nn.Module):
"""
    Remove the teacher (softmax) attention after distillation, if it was kept
"""
for layer in traverse_layers(llama_model):
if getattr(layer.self_attn, "base_attn", False):
del layer.self_attn.base_attn
return llama_model
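# Note: a layer only carries base_attn when it was converted with
# remove_base_attn=False, i.e. the original softmax attention was kept around as the
# distillation teacher; calling remove_base_attention afterwards drops those weights.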
def traverse_layers(model: nn.Module, verbose: bool = False):
"""
Return list of model layers
"""
try:
layers = model.model.layers
if verbose:
LOG.info("-> Loading from model.model.layers")
except AttributeError as e: # if base model
if verbose:
LOG.info(e)
try:
layers = model.layers
if verbose:
LOG.info("-> Loading from model.layers")
except AttributeError as e1: # If we make a PEFT model
if verbose:
LOG.info(e1)
layers = model.base_model.model.model.layers
if verbose:
LOG.info("-> Loading from model.base_model.model.model.layers")
return layers
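# Resolution order (mirrors the try/except chain above): a LlamaForCausalLM resolves
# via model.model.layers, a bare LlamaModel via model.layers, and a PEFT-wrapped
# causal LM via model.base_model.model.model.layers, e.g. (hypothetical checkpoint):
#
#   from transformers import LlamaForCausalLM
#   causal_lm = LlamaForCausalLM.from_pretrained("path/to/llama")
#   assert traverse_layers(causal_lm) is causal_lm.model.layers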
def convert_llama_attention(
layer: nn.Module,
attention_config: dict,
layers: list[nn.Module], # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Converts a single layer's attention layer as specified by attention_config
"""
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
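# get_attention(**attention_config) forwards every key of attention_config (including
# "attention_type" itself) to the chosen attention class constructor, so the classes
# in .attention are expected to accept these as keyword arguments.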
def get_attention(attention_type: str, **kwargs):
"""
Get the linear attention class; either purely linear or linear with sliding window
-> 'linear' == 'lolcats_llama'
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
"""
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .attention import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .attention import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .attention import LolcatsLinearSlidingWindowAttention
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .attention import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .attention import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
    # TK generation build (requires ThunderKittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
else:
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
return None
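# get_attention returns a functools.partial factory rather than an instance; the
# per-layer arguments are bound later by convert_llama_attention, e.g. (the
# feature_dim kwarg and index values below are purely hypothetical):
#
#   attn_factory = get_attention("lolcats_llama", feature_dim=64)
#   new_attn = attn_factory(base_attn=layer.self_attn, layer_idx=0, max_layer_idx=31,
#                           train_attention=True, remove_base_attn=True)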
def get_attention_cache(attention_type: str, past_key_values: Any = None):
"""
Determine how we store past keys and values when generating
"""
if attention_type is None:
return past_key_values
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .attention import LinearAttentionTKWindowGenerationCache
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .attention import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .attention import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .attention import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LinearAttentionTKWindowGenerationCache
return LinearAttentionTKWindowGenerationCache()
elif "softmax" in attention_type:
return past_key_values
else:
from .attention import LinearAttentionState
return LinearAttentionState()
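# Dispatch sketch (types taken from the branches above): anything containing
# "llama_window_sw" gets the sliding-window cache, anything containing "softmax"
# falls back to the regular past_key_values, and any other linear type defaults
# to LinearAttentionState, e.g.
#
#   get_attention_cache("lolcats_llama_window_sw")  # -> LinearAttentionSlidingWindowCache
#   get_attention_cache("lolcats_llama")            # -> LinearAttentionState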
def register_linear_llama():
"""
Register Linear LLaMA model with the Transformers library.

View File

@@ -1,218 +0,0 @@
"""
Convert attention to linear attention
Adapted from: https://github.com/HazyResearch/lolcats/blob/main/src/model/convert_model.py
@misc{zhang2024lolcatslowranklinearizinglarge,
title={LoLCATs: On Low-Rank Linearizing of Large Language Models},
author={Michael Zhang and Simran Arora and Rahul Chalamala and Alan Wu and Benjamin Spector and Aaryan Singhal and Krithik Ramesh and Christopher Ré},
year={2024},
eprint={2410.10254},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.10254},
}
"""
import logging
from functools import partial
from typing import Any
import torch.nn as nn
from tqdm import tqdm
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.integrations.lolcats.linearize_attention")
def convert_attention(
model: nn.Module,
attention_config: DictDefault,
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Call to convert all attention layers
"""
softmax_attns = []
if "softmax_attentions" in attention_config:
softmax_attns = attention_config["softmax_attentions"]
if attention_config.attention_type != "softmax":
layers = traverse_layers(model)
for layer_idx, layer in enumerate(
tqdm(layers, desc="Converting attentions...")
):
if layer_idx not in softmax_attns:
layer.self_attn = convert_llama_attention(
layer,
attention_config,
layers,
train_attention,
remove_base_attn,
)
layer.self_attn.converted = True
else: # Freeze any preserved softmax attention layers
for p in layer.parameters():
p.requires_grad = False
else:
LOG.info(
f"-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions"
)
return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
"""
Make attentions trainable if train is True
-> Set train_attention = False when finetuning
"""
for layer in traverse_layers(llama_model):
layer.self_attn.train_attention = train
return llama_model
def remove_base_attention(llama_model: nn.Module):
"""
Remove teacher attention after distillation (if we keep it)
"""
for layer in traverse_layers(llama_model):
if getattr(layer.self_attn, "base_attn", False):
del layer.self_attn.base_attn
return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
"""
Return list of model layers
"""
try:
layers = model.model.layers
if verbose:
LOG.info("-> Loading from model.model.layers")
except AttributeError as e: # if base model
if verbose:
LOG.info(e)
try:
layers = model.layers
if verbose:
LOG.info("-> Loading from model.layers")
except AttributeError as e1: # If we make a PEFT model
if verbose:
LOG.info(e1)
layers = model.base_model.model.model.layers
if verbose:
LOG.info("-> Loading from model.base_model.model.model.layers")
return layers
def convert_llama_attention(
layer: nn.Module,
attention_config: DictDefault,
layers: list[nn.Module], # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Converts a single layer's attention layer as specified by attention_config
"""
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
def get_attention(attention_type: str, **kwargs):
"""
Get the linear attention class; either purely linear or linear with sliding window
-> 'linear' == 'lolcats_llama'
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
"""
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .linear_llama.attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .linear_llama.attention import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .linear_llama.attention import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .linear_llama.attention import LolcatsLinearSlidingWindowAttention
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .linear_llama.attention import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .linear_llama.attention import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_llama.attention import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
else:
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
return None
def get_attention_cache(attention_type: str, past_key_values: Any = None):
"""
Determine how we store past keys and values when generating
"""
if attention_type is None:
return past_key_values
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .linear_llama.attention import LinearAttentionTKWindowGenerationCache
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .linear_llama.attention import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .linear_llama.attention import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .linear_llama.attention import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_llama.attention import LinearAttentionTKWindowGenerationCache
return LinearAttentionTKWindowGenerationCache()
elif "softmax" in attention_type:
return past_key_values
else:
from .linear_llama.attention import LinearAttentionState
return LinearAttentionState()