This commit is contained in:
Dan Saunders
2024-12-27 21:24:16 +00:00
parent 3bc568eb27
commit 78e0ec0aa5
2 changed files with 89 additions and 19 deletions

View File

@@ -18,6 +18,7 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
from axolotl.integrations.diff_transformer.modeling_diff_attn import ( from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig, LlamaDifferentialConfig,
LlamaDifferentialForCausalLM, LlamaDifferentialForCausalLM,
register_diff_attn,
) )
from axolotl.utils.yaml import dump_yaml_preserved_order from axolotl.utils.yaml import dump_yaml_preserved_order
@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
def convert_diff_transformer(cfg, cli_args, config_path): def convert_diff_transformer(cfg, cli_args, config_path):
register_diff_attn()
debug_info = {} debug_info = {}
# Load model and tokenizer # Load model and tokenizer
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET + Fore.RESET
) )
try: try:
model = LlamaDifferentialForCausalLM.from_llama( config = LlamaDifferentialConfig(
model, **model.config.__dict__,
LlamaDifferentialConfig( zero_init=cli_args.zero_init,
**model.config.__dict__, sublayer_norm=cli_args.sublayer_norm,
zero_init=cli_args.zero_init, split_heads=cli_args.split_heads,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
),
) )
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype) model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc: except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))

View File

@@ -1,6 +1,7 @@
"""Modeling for differential transformers.""" """Modeling for differential transformers."""
from typing import Optional import logging
from typing import Optional, Union
import torch import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,6 +19,8 @@ from .diff_attn import (
LlamaDifferentialSdpaAttention, LlamaDifferentialSdpaAttention,
) )
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig): class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model.""" """Configuration class for Differential LLaMA model."""
@@ -55,26 +58,85 @@ class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
class LlamaDifferentialModel(LlamaModel): class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention.""" """LlamaModel with differential attention."""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer # Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers): for idx, layer in enumerate(self.layers):
attn_impl = config._attn_implementation or "eager" layer.self_attn = attn_class(config, idx)
if attn_impl == "eager":
layer.self_attn = LlamaDifferentialAttention(config, idx) # pylint: disable=protected-access
elif attn_impl == "sdpa": @classmethod
layer.self_attn = LlamaDifferentialSdpaAttention(config, idx) def _autoset_attn_implementation(
elif attn_impl == "flash_attention_2": cls, config, **kwargs
layer.self_attn = LlamaDifferentialFlashAttention2(config, idx) ): # pylint: disable=unused-argument
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod @classmethod
def from_llama( def from_llama(
cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None cls,
model: Union[LlamaModel, LlamaForCausalLM],
config: Optional[LlamaDifferentialConfig] = None,
) -> "LlamaDifferentialModel": ) -> "LlamaDifferentialModel":
"""Convert a LlamaModel to use differential attention.""" """Convert a LlamaModel to use differential attention."""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None: if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__) config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode # Validate head counts if using split heads mode
if config.split_heads: if config.split_heads:
@@ -92,10 +154,14 @@ class LlamaDifferentialModel(LlamaModel):
new_model = cls(config) new_model = cls(config)
# Copy all weights except attention # Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict()) new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict()) new_model.norm.load_state_dict(model.norm.state_dict())
for new_layer, old_layer in zip(new_model.layers, model.layers): logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights # Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict()) new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict( new_layer.input_layernorm.load_state_dict(
@@ -109,7 +175,6 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.v_proj.load_state_dict( new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict() old_layer.self_attn.v_proj.state_dict()
) )
print(old_layer.self_attn.o_proj.weight.shape)
new_layer.self_attn.o_proj.load_state_dict( new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict() old_layer.self_attn.o_proj.state_dict()
) )
@@ -119,6 +184,9 @@ class LlamaDifferentialModel(LlamaModel):
old_k_size = old_layer.self_attn.k_proj.weight.size(0) old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads: if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_( new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data old_layer.self_attn.q_proj.weight.data
) )
@@ -127,6 +195,7 @@ class LlamaDifferentialModel(LlamaModel):
) )
if config.zero_init: if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
# Zero out components as needed # Zero out components as needed
with torch.no_grad(): with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_() new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
@@ -137,6 +206,7 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.lambda_k2.zero_() new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_() new_layer.self_attn.lambda_init.zero_()
logger.info("Conversion complete")
return new_model return new_model