This commit is contained in:
Dan Saunders
2024-12-27 21:24:16 +00:00
parent 3bc568eb27
commit 78e0ec0aa5
2 changed files with 89 additions and 19 deletions

View File

@@ -18,6 +18,7 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
register_diff_attn,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
def convert_diff_transformer(cfg, cli_args, config_path):
register_diff_attn()
debug_info = {}
# Load model and tokenizer
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET
)
try:
model = LlamaDifferentialForCausalLM.from_llama(
model,
LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
),
config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))

View File

@@ -1,6 +1,7 @@
"""Modeling for differential transformers."""
from typing import Optional
import logging
from typing import Optional, Union
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,6 +19,8 @@ from .diff_attn import (
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model."""
@@ -55,26 +58,85 @@ class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention."""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config):
super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers):
attn_impl = config._attn_implementation or "eager"
if attn_impl == "eager":
layer.self_attn = LlamaDifferentialAttention(config, idx)
elif attn_impl == "sdpa":
layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
elif attn_impl == "flash_attention_2":
layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
layer.self_attn = attn_class(config, idx)
# pylint: disable=protected-access
@classmethod
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None
cls,
model: Union[LlamaModel, LlamaForCausalLM],
config: Optional[LlamaDifferentialConfig] = None,
) -> "LlamaDifferentialModel":
"""Convert a LlamaModel to use differential attention."""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode
if config.split_heads:
@@ -92,10 +154,14 @@ class LlamaDifferentialModel(LlamaModel):
new_model = cls(config)
# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
for new_layer, old_layer in zip(new_model.layers, model.layers):
logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
@@ -109,7 +175,6 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
print(old_layer.self_attn.o_proj.weight.shape)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
@@ -119,6 +184,9 @@ class LlamaDifferentialModel(LlamaModel):
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
@@ -127,6 +195,7 @@ class LlamaDifferentialModel(LlamaModel):
)
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
# Zero out components as needed
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
@@ -137,6 +206,7 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
logger.info("Conversion complete")
return new_model