changes
This commit is contained in:
@@ -18,6 +18,7 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
|
|||||||
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
|
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
|
||||||
LlamaDifferentialConfig,
|
LlamaDifferentialConfig,
|
||||||
LlamaDifferentialForCausalLM,
|
LlamaDifferentialForCausalLM,
|
||||||
|
register_diff_attn,
|
||||||
)
|
)
|
||||||
from axolotl.utils.yaml import dump_yaml_preserved_order
|
from axolotl.utils.yaml import dump_yaml_preserved_order
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
|||||||
|
|
||||||
|
|
||||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||||
|
register_diff_attn()
|
||||||
debug_info = {}
|
debug_info = {}
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
|||||||
+ Fore.RESET
|
+ Fore.RESET
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
model = LlamaDifferentialForCausalLM.from_llama(
|
config = LlamaDifferentialConfig(
|
||||||
model,
|
**model.config.__dict__,
|
||||||
LlamaDifferentialConfig(
|
zero_init=cli_args.zero_init,
|
||||||
**model.config.__dict__,
|
sublayer_norm=cli_args.sublayer_norm,
|
||||||
zero_init=cli_args.zero_init,
|
split_heads=cli_args.split_heads,
|
||||||
sublayer_norm=cli_args.sublayer_norm,
|
|
||||||
split_heads=cli_args.split_heads,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
||||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Modeling for differential transformers."""
|
"""Modeling for differential transformers."""
|
||||||
|
|
||||||
from typing import Optional
|
import logging
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||||
@@ -18,6 +19,8 @@ from .diff_attn import (
|
|||||||
LlamaDifferentialSdpaAttention,
|
LlamaDifferentialSdpaAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialConfig(LlamaConfig):
|
class LlamaDifferentialConfig(LlamaConfig):
|
||||||
"""Configuration class for Differential LLaMA model."""
|
"""Configuration class for Differential LLaMA model."""
|
||||||
@@ -55,26 +58,85 @@ class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
|
|||||||
class LlamaDifferentialModel(LlamaModel):
|
class LlamaDifferentialModel(LlamaModel):
|
||||||
"""LlamaModel with differential attention."""
|
"""LlamaModel with differential attention."""
|
||||||
|
|
||||||
|
config_class = LlamaDifferentialConfig
|
||||||
|
base_model_prefix = "llama_differential"
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Handle attention implementation
|
||||||
|
attn_impl = config._attn_implementation or "eager"
|
||||||
|
if attn_impl in config._attn_implementations:
|
||||||
|
attn_impl = config._attn_implementations[attn_impl]
|
||||||
|
|
||||||
|
# Validate attention implementation
|
||||||
|
valid_impls = [
|
||||||
|
None,
|
||||||
|
"differential_eager",
|
||||||
|
"differential_sdpa",
|
||||||
|
"differential_flash_attention_2",
|
||||||
|
]
|
||||||
|
if attn_impl not in valid_impls:
|
||||||
|
raise ValueError(f"Invalid attention implementation: {attn_impl}")
|
||||||
|
|
||||||
# Replace standard attention with differential attention in each layer
|
# Replace standard attention with differential attention in each layer
|
||||||
|
attn_classes = {
|
||||||
|
"differential_eager": LlamaDifferentialAttention,
|
||||||
|
"differential_sdpa": LlamaDifferentialSdpaAttention,
|
||||||
|
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
|
||||||
|
}
|
||||||
|
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
|
||||||
|
|
||||||
for idx, layer in enumerate(self.layers):
|
for idx, layer in enumerate(self.layers):
|
||||||
attn_impl = config._attn_implementation or "eager"
|
layer.self_attn = attn_class(config, idx)
|
||||||
if attn_impl == "eager":
|
|
||||||
layer.self_attn = LlamaDifferentialAttention(config, idx)
|
# pylint: disable=protected-access
|
||||||
elif attn_impl == "sdpa":
|
@classmethod
|
||||||
layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
|
def _autoset_attn_implementation(
|
||||||
elif attn_impl == "flash_attention_2":
|
cls, config, **kwargs
|
||||||
layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
|
): # pylint: disable=unused-argument
|
||||||
|
config._attn_implementation_autoset = True
|
||||||
|
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||||
|
|
||||||
|
# Map standard types to differential types if mapping exists
|
||||||
|
if attn_implementation in config._attn_implementations:
|
||||||
|
config._attn_implementation = config._attn_implementations[
|
||||||
|
attn_implementation
|
||||||
|
]
|
||||||
|
return config
|
||||||
|
|
||||||
|
# If no mapping, validate it's a valid differential type
|
||||||
|
valid_impls = [
|
||||||
|
None,
|
||||||
|
"differential_eager",
|
||||||
|
"differential_sdpa",
|
||||||
|
"differential_flash_attention_2",
|
||||||
|
]
|
||||||
|
if attn_implementation not in valid_impls:
|
||||||
|
message = (
|
||||||
|
f"Specified `attn_implementation={attn_implementation}` is not supported. "
|
||||||
|
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
|
||||||
|
)
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(
|
def from_llama(
|
||||||
cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None
|
cls,
|
||||||
|
model: Union[LlamaModel, LlamaForCausalLM],
|
||||||
|
config: Optional[LlamaDifferentialConfig] = None,
|
||||||
) -> "LlamaDifferentialModel":
|
) -> "LlamaDifferentialModel":
|
||||||
"""Convert a LlamaModel to use differential attention."""
|
"""Convert a LlamaModel to use differential attention."""
|
||||||
|
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
|
||||||
|
|
||||||
|
# Handle LlamaForCausalLM
|
||||||
|
if isinstance(model, LlamaForCausalLM):
|
||||||
|
model = model.model
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||||
|
logger.debug(f"Created config: {config}")
|
||||||
|
|
||||||
# Validate head counts if using split heads mode
|
# Validate head counts if using split heads mode
|
||||||
if config.split_heads:
|
if config.split_heads:
|
||||||
@@ -92,10 +154,14 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_model = cls(config)
|
new_model = cls(config)
|
||||||
|
|
||||||
# Copy all weights except attention
|
# Copy all weights except attention
|
||||||
|
logger.debug("Copying embeddings and norm")
|
||||||
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
|
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
|
||||||
new_model.norm.load_state_dict(model.norm.state_dict())
|
new_model.norm.load_state_dict(model.norm.state_dict())
|
||||||
|
|
||||||
for new_layer, old_layer in zip(new_model.layers, model.layers):
|
logger.debug("Copying layer weights")
|
||||||
|
for layer_idx, (new_layer, old_layer) in enumerate(
|
||||||
|
zip(new_model.layers, model.layers)
|
||||||
|
):
|
||||||
# Copy everything except attention weights
|
# Copy everything except attention weights
|
||||||
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
|
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
|
||||||
new_layer.input_layernorm.load_state_dict(
|
new_layer.input_layernorm.load_state_dict(
|
||||||
@@ -109,7 +175,6 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_layer.self_attn.v_proj.load_state_dict(
|
new_layer.self_attn.v_proj.load_state_dict(
|
||||||
old_layer.self_attn.v_proj.state_dict()
|
old_layer.self_attn.v_proj.state_dict()
|
||||||
)
|
)
|
||||||
print(old_layer.self_attn.o_proj.weight.shape)
|
|
||||||
new_layer.self_attn.o_proj.load_state_dict(
|
new_layer.self_attn.o_proj.load_state_dict(
|
||||||
old_layer.self_attn.o_proj.state_dict()
|
old_layer.self_attn.o_proj.state_dict()
|
||||||
)
|
)
|
||||||
@@ -119,6 +184,9 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
|
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
|
||||||
|
|
||||||
if not config.split_heads:
|
if not config.split_heads:
|
||||||
|
logger.debug(
|
||||||
|
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
|
||||||
|
)
|
||||||
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
|
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
|
||||||
old_layer.self_attn.q_proj.weight.data
|
old_layer.self_attn.q_proj.weight.data
|
||||||
)
|
)
|
||||||
@@ -127,6 +195,7 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.zero_init:
|
if config.zero_init:
|
||||||
|
logger.debug(f"Layer {layer_idx}: Zero initializing")
|
||||||
# Zero out components as needed
|
# Zero out components as needed
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
||||||
@@ -137,6 +206,7 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_layer.self_attn.lambda_k2.zero_()
|
new_layer.self_attn.lambda_k2.zero_()
|
||||||
new_layer.self_attn.lambda_init.zero_()
|
new_layer.self_attn.lambda_init.zero_()
|
||||||
|
|
||||||
|
logger.info("Conversion complete")
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user