changes
This commit is contained in:
@@ -18,6 +18,7 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
|
||||
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
|
||||
LlamaDifferentialConfig,
|
||||
LlamaDifferentialForCausalLM,
|
||||
register_diff_attn,
|
||||
)
|
||||
from axolotl.utils.yaml import dump_yaml_preserved_order
|
||||
|
||||
@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||
|
||||
|
||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||
register_diff_attn()
|
||||
debug_info = {}
|
||||
|
||||
# Load model and tokenizer
|
||||
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
||||
+ Fore.RESET
|
||||
)
|
||||
try:
|
||||
model = LlamaDifferentialForCausalLM.from_llama(
|
||||
model,
|
||||
LlamaDifferentialConfig(
|
||||
**model.config.__dict__,
|
||||
zero_init=cli_args.zero_init,
|
||||
sublayer_norm=cli_args.sublayer_norm,
|
||||
split_heads=cli_args.split_heads,
|
||||
),
|
||||
config = LlamaDifferentialConfig(
|
||||
**model.config.__dict__,
|
||||
zero_init=cli_args.zero_init,
|
||||
sublayer_norm=cli_args.sublayer_norm,
|
||||
split_heads=cli_args.split_heads,
|
||||
)
|
||||
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
except Exception as exc:
|
||||
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Modeling for differential transformers."""
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
@@ -18,6 +19,8 @@ from .diff_attn import (
|
||||
LlamaDifferentialSdpaAttention,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaDifferentialConfig(LlamaConfig):
|
||||
"""Configuration class for Differential LLaMA model."""
|
||||
@@ -55,26 +58,85 @@ class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
|
||||
class LlamaDifferentialModel(LlamaModel):
|
||||
"""LlamaModel with differential attention."""
|
||||
|
||||
config_class = LlamaDifferentialConfig
|
||||
base_model_prefix = "llama_differential"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
# Handle attention implementation
|
||||
attn_impl = config._attn_implementation or "eager"
|
||||
if attn_impl in config._attn_implementations:
|
||||
attn_impl = config._attn_implementations[attn_impl]
|
||||
|
||||
# Validate attention implementation
|
||||
valid_impls = [
|
||||
None,
|
||||
"differential_eager",
|
||||
"differential_sdpa",
|
||||
"differential_flash_attention_2",
|
||||
]
|
||||
if attn_impl not in valid_impls:
|
||||
raise ValueError(f"Invalid attention implementation: {attn_impl}")
|
||||
|
||||
# Replace standard attention with differential attention in each layer
|
||||
attn_classes = {
|
||||
"differential_eager": LlamaDifferentialAttention,
|
||||
"differential_sdpa": LlamaDifferentialSdpaAttention,
|
||||
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
|
||||
}
|
||||
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
|
||||
|
||||
for idx, layer in enumerate(self.layers):
|
||||
attn_impl = config._attn_implementation or "eager"
|
||||
if attn_impl == "eager":
|
||||
layer.self_attn = LlamaDifferentialAttention(config, idx)
|
||||
elif attn_impl == "sdpa":
|
||||
layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
|
||||
elif attn_impl == "flash_attention_2":
|
||||
layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
|
||||
layer.self_attn = attn_class(config, idx)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@classmethod
|
||||
def _autoset_attn_implementation(
|
||||
cls, config, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
config._attn_implementation_autoset = True
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
# Map standard types to differential types if mapping exists
|
||||
if attn_implementation in config._attn_implementations:
|
||||
config._attn_implementation = config._attn_implementations[
|
||||
attn_implementation
|
||||
]
|
||||
return config
|
||||
|
||||
# If no mapping, validate it's a valid differential type
|
||||
valid_impls = [
|
||||
None,
|
||||
"differential_eager",
|
||||
"differential_sdpa",
|
||||
"differential_flash_attention_2",
|
||||
]
|
||||
if attn_implementation not in valid_impls:
|
||||
message = (
|
||||
f"Specified `attn_implementation={attn_implementation}` is not supported. "
|
||||
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None
|
||||
cls,
|
||||
model: Union[LlamaModel, LlamaForCausalLM],
|
||||
config: Optional[LlamaDifferentialConfig] = None,
|
||||
) -> "LlamaDifferentialModel":
|
||||
"""Convert a LlamaModel to use differential attention."""
|
||||
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
|
||||
|
||||
# Handle LlamaForCausalLM
|
||||
if isinstance(model, LlamaForCausalLM):
|
||||
model = model.model
|
||||
|
||||
if config is None:
|
||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||
logger.debug(f"Created config: {config}")
|
||||
|
||||
# Validate head counts if using split heads mode
|
||||
if config.split_heads:
|
||||
@@ -92,10 +154,14 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
new_model = cls(config)
|
||||
|
||||
# Copy all weights except attention
|
||||
logger.debug("Copying embeddings and norm")
|
||||
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
|
||||
new_model.norm.load_state_dict(model.norm.state_dict())
|
||||
|
||||
for new_layer, old_layer in zip(new_model.layers, model.layers):
|
||||
logger.debug("Copying layer weights")
|
||||
for layer_idx, (new_layer, old_layer) in enumerate(
|
||||
zip(new_model.layers, model.layers)
|
||||
):
|
||||
# Copy everything except attention weights
|
||||
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
|
||||
new_layer.input_layernorm.load_state_dict(
|
||||
@@ -109,7 +175,6 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
new_layer.self_attn.v_proj.load_state_dict(
|
||||
old_layer.self_attn.v_proj.state_dict()
|
||||
)
|
||||
print(old_layer.self_attn.o_proj.weight.shape)
|
||||
new_layer.self_attn.o_proj.load_state_dict(
|
||||
old_layer.self_attn.o_proj.state_dict()
|
||||
)
|
||||
@@ -119,6 +184,9 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
|
||||
|
||||
if not config.split_heads:
|
||||
logger.debug(
|
||||
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
|
||||
)
|
||||
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
|
||||
old_layer.self_attn.q_proj.weight.data
|
||||
)
|
||||
@@ -127,6 +195,7 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
)
|
||||
|
||||
if config.zero_init:
|
||||
logger.debug(f"Layer {layer_idx}: Zero initializing")
|
||||
# Zero out components as needed
|
||||
with torch.no_grad():
|
||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
||||
@@ -137,6 +206,7 @@ class LlamaDifferentialModel(LlamaModel):
|
||||
new_layer.self_attn.lambda_k2.zero_()
|
||||
new_layer.self_attn.lambda_init.zero_()
|
||||
|
||||
logger.info("Conversion complete")
|
||||
return new_model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user