refactor: move to modeling file and remove axolotl imports

This commit is contained in:
NanoCode012
2025-02-05 18:16:39 +07:00
parent 2fd5c45c2e
commit 2d5f692fc0
2 changed files with 204 additions and 227 deletions

View File

@@ -9,8 +9,12 @@
"""Linear LLaMA model implementation.""" """Linear LLaMA model implementation."""
import logging
from functools import partial
from typing import Any
from torch import nn from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer, LlamaDecoderLayer,
LlamaForCausalLM, LlamaForCausalLM,
@@ -19,11 +23,11 @@ from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding, LlamaRotaryEmbedding,
) )
from axolotl.utils.dict import DictDefault
from .attention import LolcatsLinearAttention from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
class LinearLlamaDecoderLayer(LlamaDecoderLayer): class LinearLlamaDecoderLayer(LlamaDecoderLayer):
""" """
@@ -109,11 +113,10 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
# initialize the model with prior weights # initialize the model with prior weights
new_model = cls(config=config) new_model = cls(config=config)
from axolotl.integrations.lolcats.linearize_attention import convert_attention del new_model.model # remove the default model
new_model.model = convert_attention( new_model.model = convert_attention(
model.model, model.model,
DictDefault(**config.attention_config), attention_config=config.attention_config,
train_attention=train_attention, train_attention=train_attention,
remove_base_attn=remove_base_attn, remove_base_attn=remove_base_attn,
) )
@@ -126,7 +129,6 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
""" """
Toggle attention to be trainable or not Toggle attention to be trainable or not
""" """
from axolotl.integrations.lolcats.linearize_attention import toggle_attention
toggle_attention(self.model, train=train) toggle_attention(self.model, train=train)
@@ -134,13 +136,206 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
""" """
Remove base attention after distillation Remove base attention after distillation
""" """
from axolotl.integrations.lolcats.linearize_attention import (
remove_base_attention,
)
remove_base_attention(self.model) remove_base_attention(self.model)
def convert_attention(
    model: nn.Module,
    attention_config: dict,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Convert the attention of every layer of ``model`` as described by ``attention_config``.

    When ``attention_config["attention_type"]`` is ``"softmax"`` nothing is
    converted. Otherwise each layer returned by ``traverse_layers`` is handled
    by index: layers listed under ``attention_config["softmax_attentions"]``
    keep their attention and get all of their parameters frozen, every other
    layer has ``self_attn`` replaced by the result of
    ``convert_llama_attention`` and flagged with ``converted = True``.

    Args:
        model: Model whose layers are located with ``traverse_layers``.
        attention_config: Attention settings. ``attention_type`` and the
            optional ``softmax_attentions`` are read here; the whole mapping
            is forwarded to ``convert_llama_attention``.
        train_attention: Forwarded to ``convert_llama_attention``.
        remove_base_attn: Forwarded to ``convert_llama_attention``.

    Returns:
        The ``model`` object that was passed in.
    """
    # Get the attention to convert to
    attention_type = attention_config.get("attention_type")
    if attention_type == "softmax":
        LOG.info(
            f"-> attention_config.attention_type is {attention_type}; not converting attentions"
        )
        return model

    # Get the layers to keep as softmax, if provided. ``or []`` also covers a
    # key that is present but explicitly set to None, which would otherwise
    # make the membership test below raise TypeError.
    softmax_attns = attention_config.get("softmax_attentions") or []

    layers = traverse_layers(model)
    for layer_idx, layer in enumerate(tqdm(layers, desc="Converting attentions...")):
        if layer_idx in softmax_attns:
            # Freeze any preserved softmax attention layers
            for p in layer.parameters():
                p.requires_grad = False
            continue

        layer.self_attn = convert_llama_attention(
            layer,
            attention_config,
            layers,
            train_attention,
            remove_base_attn,
        )
        layer.self_attn.converted = True
    return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Set the ``train_attention`` flag of every layer's ``self_attn`` to ``train``
    -> Set train_attention = False when finetuning

    Returns the ``llama_model`` object that was passed in.
    """
    decoder_layers = traverse_layers(llama_model)
    for decoder_layer in decoder_layers:
        attention = decoder_layer.self_attn
        attention.train_attention = train
    return llama_model
def remove_base_attention(llama_model: nn.Module):
    """
    Drop the teacher attention kept on each layer after distillation

    For every layer returned by ``traverse_layers``, the ``base_attn``
    attribute of ``self_attn`` is deleted when it exists and is truthy.
    Returns the ``llama_model`` object that was passed in.
    """
    for decoder_layer in traverse_layers(llama_model):
        attention = decoder_layer.self_attn
        if not getattr(attention, "base_attn", False):
            # Nothing to remove on this layer
            continue
        del attention.base_attn
    return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Locate and return the layers container of ``model``

    Three attribute paths are tried in order: ``model.model.layers``,
    ``model.layers`` and ``model.base_model.model.model.layers``. An
    ``AttributeError`` on either of the first two falls through to the next
    path; an ``AttributeError`` on the last one propagates to the caller.
    With ``verbose`` set, each failed lookup and the path finally used are
    logged.
    """
    try:
        found = model.model.layers
    except AttributeError as wrapped_err:  # if base model
        if verbose:
            LOG.info(wrapped_err)
    else:
        if verbose:
            LOG.info("-> Loading from model.model.layers")
        return found

    try:
        found = model.layers
    except AttributeError as base_err:  # If we make a PEFT model
        if verbose:
            LOG.info(base_err)
    else:
        if verbose:
            LOG.info("-> Loading from model.layers")
        return found

    found = model.base_model.model.model.layers
    if verbose:
        LOG.info("-> Loading from model.base_model.model.model.layers")
    return found
def convert_llama_attention(
    layer: nn.Module,
    attention_config: dict,
    layers: list[nn.Module],  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Converts a single layer's attention layer as specified by attention_config

    Args:
        layer: Layer whose ``self_attn`` is used as the base attention.
        attention_config: Keyword arguments for ``get_attention``; must
            include ``attention_type``.
        layers: All layers of the model; only the length is used, to derive
            ``max_layer_idx``.
        train_attention: Forwarded to the attention constructor.
        remove_base_attn: Forwarded to the attention constructor.

    Returns:
        The object built by calling the factory from ``get_attention``.

    Raises:
        ValueError: If ``get_attention`` returns ``None`` for the configured
            ``attention_type``.
    """
    attention_factory = get_attention(**attention_config)
    if attention_factory is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not callable" that calling None would raise.
        raise ValueError(
            f"Unsupported attention_type {attention_config.get('attention_type')!r}; "
            "cannot convert attention layer"
        )

    base_attn = layer.self_attn
    return attention_factory(
        base_attn=base_attn,
        layer_idx=base_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
def get_attention(attention_type: str, **kwargs):
    """
    Build a factory for the attention class selected by ``attention_type``
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'

    The returned ``functools.partial`` binds ``kwargs`` (with
    ``attention_type`` added) to the selected class. ``None`` is returned,
    after logging, when ``attention_type`` matches none of the known names.
    """
    kwargs["attention_type"] = attention_type

    # Each branch only imports its class; the partial is built once below.
    if attention_type == "lolcats_llama":
        from .attention import LolcatsLinearAttention as attention_cls
    elif attention_type == "lolcats_llama_window_tk":
        from .attention import LolcatsTKWindowAttention as attention_cls
    elif attention_type == "lolcats_llama_window_sw":
        from .attention import LolcatsSlidingWindowAttention as attention_cls
    elif attention_type == "lolcats_llama_window_sw_linear":
        from .attention import LolcatsLinearSlidingWindowAttention as attention_cls
    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .attention import LolcatsTKWindowLongAttention as attention_cls
    elif attention_type == "lolcats_long_llama_window_sw":
        from .attention import LolcatsSlidingWindowLongAttention as attention_cls
    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .attention import LolcatsWindowAttentionTKGen as attention_cls
    else:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None

    return partial(attention_cls, **kwargs)
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating

    Args:
        attention_type: Name of the attention variant, or ``None``.
        past_key_values: Existing cache; returned unchanged when
            ``attention_type`` is ``None`` or is a softmax type that matched
            none of the window patterns.

    Returns:
        ``past_key_values`` in the two cases above, otherwise a new cache
        object chosen by substring match on ``attention_type``.
    """
    if attention_type is None:
        return past_key_values

    # Order matters: the checks are substring matches, so the most specific
    # pattern has to be tested first.
    # TK generation build (requires Thunderkittens)
    if "lolcats_llama_window_tk_gen" in attention_type:
        from .attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    if "llama_window_tk" in attention_type:
        from .attention import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    # Also matches the "*_window_sw_linear" variants, which use the same cache.
    if "llama_window_sw" in attention_type:
        from .attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    if "softmax" in attention_type:
        return past_key_values

    from .attention import LinearAttentionState

    return LinearAttentionState()
def register_linear_llama(): def register_linear_llama():
""" """
Register Linear LLaMA model with the Transformers library. Register Linear LLaMA model with the Transformers library.

View File

@@ -1,218 +0,0 @@
"""
Convert attention to linear attention
Adapted from: https://github.com/HazyResearch/lolcats/blob/main/src/model/convert_model.py
@misc{zhang2024lolcatslowranklinearizinglarge,
title={LoLCATs: On Low-Rank Linearizing of Large Language Models},
author={Michael Zhang and Simran Arora and Rahul Chalamala and Alan Wu and Benjamin Spector and Aaryan Singhal and Krithik Ramesh and Christopher Ré},
year={2024},
eprint={2410.10254},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.10254},
}
"""
import logging
from functools import partial
from typing import Any
import torch.nn as nn
from tqdm import tqdm
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.integrations.lolcats.linearize_attention")
def convert_attention(
    model: nn.Module,
    attention_config: DictDefault,
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Call to convert all attention layers

    When ``attention_config.attention_type`` is ``"softmax"`` nothing is
    converted. Otherwise each layer returned by ``traverse_layers`` is handled
    by index: layers listed under ``softmax_attentions`` keep their attention
    and get all of their parameters frozen, every other layer has
    ``self_attn`` replaced by the result of ``convert_llama_attention`` and
    flagged with ``converted = True``.

    Args:
        model: Model whose layers are located with ``traverse_layers``.
        attention_config: Attention settings. ``attention_type`` and the
            optional ``softmax_attentions`` are read here; the whole config
            is forwarded to ``convert_llama_attention``.
        train_attention: Forwarded to ``convert_llama_attention``.
        remove_base_attn: Forwarded to ``convert_llama_attention``.

    Returns:
        The ``model`` object that was passed in.
    """
    attention_type = attention_config.attention_type
    if attention_type == "softmax":
        LOG.info(
            f"-> attention_config.attention_type is {attention_type}; not converting attentions"
        )
        return model

    softmax_attns = []
    if "softmax_attentions" in attention_config:
        # ``or []`` covers a key that is present but set to None, which would
        # otherwise make the membership test below raise TypeError.
        softmax_attns = attention_config["softmax_attentions"] or []

    layers = traverse_layers(model)
    for layer_idx, layer in enumerate(tqdm(layers, desc="Converting attentions...")):
        if layer_idx in softmax_attns:
            # Freeze any preserved softmax attention layers
            for p in layer.parameters():
                p.requires_grad = False
            continue

        layer.self_attn = convert_llama_attention(
            layer,
            attention_config,
            layers,
            train_attention,
            remove_base_attn,
        )
        layer.self_attn.converted = True
    return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Write ``train`` into ``train_attention`` on the ``self_attn`` of each layer
    -> Set train_attention = False when finetuning

    The layers come from ``traverse_layers``; the ``llama_model`` object that
    was passed in is returned.
    """
    all_layers = traverse_layers(llama_model)
    for current in all_layers:
        setattr(current.self_attn, "train_attention", train)
    return llama_model
def remove_base_attention(llama_model: nn.Module):
    """
    Remove the teacher attention after distillation (if it was kept)

    Walks the layers from ``traverse_layers`` and deletes ``base_attn`` from
    each ``self_attn`` where that attribute exists and is truthy. The
    ``llama_model`` object that was passed in is returned.
    """
    for current in traverse_layers(llama_model):
        self_attn = current.self_attn
        has_teacher = getattr(self_attn, "base_attn", False)
        if has_teacher:
            del self_attn.base_attn
    return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return the layers container found on ``model``

    Lookup order: ``model.model.layers``, then ``model.layers``, then
    ``model.base_model.model.model.layers``. The first two lookups fall
    through on ``AttributeError``; a failure of the last one is raised to the
    caller. When ``verbose`` is true the failures and the chosen path are
    logged.
    """
    try:
        result = model.model.layers
    except AttributeError as first_miss:  # if base model
        if verbose:
            LOG.info(first_miss)
    else:
        if verbose:
            LOG.info("-> Loading from model.model.layers")
        return result

    try:
        result = model.layers
    except AttributeError as second_miss:  # If we make a PEFT model
        if verbose:
            LOG.info(second_miss)
    else:
        if verbose:
            LOG.info("-> Loading from model.layers")
        return result

    result = model.base_model.model.model.layers
    if verbose:
        LOG.info("-> Loading from model.base_model.model.model.layers")
    return result
def convert_llama_attention(
    layer: nn.Module,
    attention_config: DictDefault,
    layers: list[nn.Module],  # list of layers
    train_attention: bool = False,
    remove_base_attn: bool = True,
):
    """
    Converts a single layer's attention layer as specified by attention_config

    Args:
        layer: Layer whose ``self_attn`` is used as the base attention.
        attention_config: Keyword arguments for ``get_attention``; must
            include ``attention_type``.
        layers: All layers of the model; only the length is used, to derive
            ``max_layer_idx``.
        train_attention: Forwarded to the attention constructor.
        remove_base_attn: Forwarded to the attention constructor.

    Returns:
        The object built by calling the factory from ``get_attention``.

    Raises:
        ValueError: If ``get_attention`` returns ``None`` for the configured
            ``attention_type``.
    """
    attention_factory = get_attention(**attention_config)
    if attention_factory is None:
        # Calling None would only produce "'NoneType' object is not callable";
        # name the offending setting instead.
        raise ValueError(
            f"Unsupported attention_type {attention_config.get('attention_type')!r}; "
            "cannot convert attention layer"
        )

    base_attn = layer.self_attn
    return attention_factory(
        base_attn=base_attn,
        layer_idx=base_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
def get_attention(attention_type: str, **kwargs):
    """
    Pick the linear attention class for ``attention_type`` and wrap it in a partial
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'

    ``kwargs`` (with ``attention_type`` added) are bound to the selected class
    through ``functools.partial``. An unknown ``attention_type`` is logged and
    yields ``None``.
    """
    kwargs["attention_type"] = attention_type

    # Branches resolve the class only; the partial is created in one place.
    if attention_type == "lolcats_llama":
        from .linear_llama.attention import LolcatsLinearAttention as chosen_cls
    elif attention_type == "lolcats_llama_window_tk":
        from .linear_llama.attention import LolcatsTKWindowAttention as chosen_cls
    elif attention_type == "lolcats_llama_window_sw":
        from .linear_llama.attention import LolcatsSlidingWindowAttention as chosen_cls
    elif attention_type == "lolcats_llama_window_sw_linear":
        from .linear_llama.attention import (
            LolcatsLinearSlidingWindowAttention as chosen_cls,
        )
    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
        from .linear_llama.attention import LolcatsTKWindowLongAttention as chosen_cls
    elif attention_type == "lolcats_long_llama_window_sw":
        from .linear_llama.attention import (
            LolcatsSlidingWindowLongAttention as chosen_cls,
        )
    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
        from .linear_llama.attention import LolcatsWindowAttentionTKGen as chosen_cls
    else:
        LOG.info(f"-> attention_type {attention_type} not handled... returning None")
        return None

    return partial(chosen_cls, **kwargs)
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating

    Args:
        attention_type: Name of the attention variant, or ``None``.
        past_key_values: Existing cache; returned unchanged when
            ``attention_type`` is ``None`` or is a softmax type that matched
            none of the window patterns.

    Returns:
        ``past_key_values`` in the two cases above, otherwise a new cache
        object chosen by substring match on ``attention_type``.
    """
    if attention_type is None:
        return past_key_values

    # The checks are substring matches, so the most specific pattern is
    # tested first.
    # TK generation build (requires Thunderkittens)
    if "lolcats_llama_window_tk_gen" in attention_type:
        from .linear_llama.attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    if "llama_window_tk" in attention_type:
        from .linear_llama.attention import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    # Also matches the "*_window_sw_linear" variants, which use the same cache.
    if "llama_window_sw" in attention_type:
        from .linear_llama.attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    if "softmax" in attention_type:
        return past_key_values

    from .linear_llama.attention import LinearAttentionState

    return LinearAttentionState()