adding split_heads argument for retaining original (Q, K) dimensionanlity
This commit is contained in:
@@ -14,9 +14,7 @@ from transformers import HfArgumentParser
|
||||
|
||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.differential_transformer.convert import (
|
||||
convert_to_differential_attention,
|
||||
)
|
||||
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
||||
|
||||
# Convert attention
|
||||
LOG.info("Converting to differential attention...")
|
||||
if cli_args.split_heads and cli_args.zero_init:
|
||||
LOG.warning(
|
||||
Fore.YELLOW
|
||||
+ "Warning: Using split_heads with zero_init is not recommended; "
|
||||
+ "split_heads will preclude the effects of zero_init"
|
||||
+ Fore.RESET
|
||||
)
|
||||
try:
|
||||
model = convert_to_differential_attention(
|
||||
model = convert_to_diff_attn(
|
||||
model=model,
|
||||
zero_init=cli_args.zero_init,
|
||||
sublayer_norm=cli_args.sublayer_norm,
|
||||
split_heads=cli_args.split_heads,
|
||||
)
|
||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
|
||||
debug: bool = field(default=False)
|
||||
zero_init: bool = field(default=False)
|
||||
sublayer_norm: bool = field(default=True)
|
||||
split_heads: bool = field(default=False)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
||||
@@ -80,14 +80,18 @@ def copy_attention_weights(
|
||||
)
|
||||
|
||||
|
||||
def convert_to_differential_attention(
|
||||
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
|
||||
def convert_to_diff_attn(
|
||||
model: PreTrainedModel,
|
||||
zero_init: bool = False,
|
||||
sublayer_norm: bool = True,
|
||||
split_heads: bool = True,
|
||||
) -> PreTrainedModel:
|
||||
"""Convert a pre-trained model's attention layers to differential attention"""
|
||||
layer_idx = 0
|
||||
|
||||
# Set sublayer norm as config on the model.
|
||||
model.config.sublayer_norm = sublayer_norm
|
||||
model.config.split_heads = split_heads
|
||||
|
||||
def convert_module(module):
|
||||
nonlocal layer_idx
|
||||
@@ -111,7 +115,8 @@ def convert_to_differential_attention(
|
||||
|
||||
# Copy weights from old attention to new attention
|
||||
new_attention.to(child.q_proj.weight.device)
|
||||
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
||||
if not split_heads:
|
||||
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
||||
|
||||
# Replace the layer
|
||||
setattr(module, name, new_attention)
|
||||
|
||||
@@ -70,26 +70,51 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.base_num_heads = config.num_attention_heads
|
||||
self.base_num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
|
||||
if config.split_heads:
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads // 2
|
||||
else:
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.split_heads = config.split_heads
|
||||
|
||||
# For Q1 and Q2
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
if config.split_heads:
|
||||
# Split heads mode
|
||||
assert (
|
||||
self.base_num_heads % 2 == 0
|
||||
), "Number of heads must be even for splitting"
|
||||
self.heads_per_component = self.base_num_heads // 2
|
||||
|
||||
# For K1 and K2
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
|
||||
bias=False,
|
||||
)
|
||||
# Single projections
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# Double projection mode
|
||||
self.heads_per_component = self.base_num_heads
|
||||
|
||||
# Double-sized projections
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Single V projection
|
||||
self.v_proj = nn.Linear(
|
||||
@@ -125,8 +150,14 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
sublayer_norm = getattr(config, "sublayer_norm", True)
|
||||
|
||||
if self.split_heads:
|
||||
subln_dim = 2 * self.head_dim
|
||||
else:
|
||||
subln_dim = self.head_dim
|
||||
|
||||
self.subln = (
|
||||
LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
|
||||
LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
|
||||
if sublayer_norm
|
||||
else nn.Identity()
|
||||
)
|
||||
@@ -167,7 +198,10 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Reshape V
|
||||
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
if self.split_heads:
|
||||
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||
else:
|
||||
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary embeddings
|
||||
if position_embeddings is None:
|
||||
@@ -177,6 +211,10 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
if self.split_heads:
|
||||
cos, _ = cos.chunk(2, dim=2)
|
||||
sin, _ = sin.chunk(2, dim=2)
|
||||
|
||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||
|
||||
@@ -192,8 +230,6 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
|
||||
|
||||
# Calculate attention scores for both parts
|
||||
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
|
||||
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
|
||||
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
||||
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
||||
|
||||
@@ -307,13 +343,18 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
|
||||
k1, k2 = kp.chunk(2, dim=-1)
|
||||
|
||||
# Reshape Q1,Q2 for attention
|
||||
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
||||
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
||||
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Reshape K1,K2 for attention
|
||||
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Reshape V
|
||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
if self.split_heads:
|
||||
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||
else:
|
||||
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary embeddings
|
||||
if position_embeddings is None:
|
||||
@@ -323,6 +364,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
if self.split_heads:
|
||||
cos, _ = cos.chunk(2, dim=2)
|
||||
sin, _ = sin.chunk(2, dim=2)
|
||||
|
||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||
|
||||
@@ -468,13 +513,18 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
||||
k1, k2 = kp.chunk(2, dim=-1)
|
||||
|
||||
# Reshape Q1,Q2 for attention
|
||||
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
||||
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
||||
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Reshape K1,K2 for attention
|
||||
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Reshape V
|
||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
if self.split_heads:
|
||||
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||
else:
|
||||
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary embeddings
|
||||
if position_embeddings is None:
|
||||
@@ -484,6 +534,10 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
if self.split_heads:
|
||||
cos, _ = cos.chunk(2, dim=2)
|
||||
sin, _ = sin.chunk(2, dim=2)
|
||||
|
||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||
|
||||
@@ -506,20 +560,54 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
||||
|
||||
# Calculate attention using Flash Attention
|
||||
dropout_p = self.attention_dropout if self.training else 0.0
|
||||
attn1 = flash_attn_func(
|
||||
q1,
|
||||
k1,
|
||||
v,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn2 = flash_attn_func(
|
||||
q2,
|
||||
k2,
|
||||
v,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
if self.split_heads:
|
||||
v1, v2 = v.chunk(2, dim=-1)
|
||||
attn11 = flash_attn_func(
|
||||
q1,
|
||||
k1,
|
||||
v1,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn12 = flash_attn_func(
|
||||
q1,
|
||||
k1,
|
||||
v2,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn1 = torch.cat([attn11, attn12], dim=-1)
|
||||
|
||||
attn21 = flash_attn_func(
|
||||
q2,
|
||||
k2,
|
||||
v1,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn22 = flash_attn_func(
|
||||
q2,
|
||||
k2,
|
||||
v2,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn2 = torch.cat([attn21, attn22], dim=-1)
|
||||
else:
|
||||
attn1 = flash_attn_func(
|
||||
q1,
|
||||
k1,
|
||||
v,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
attn2 = flash_attn_func(
|
||||
q2,
|
||||
k2,
|
||||
v,
|
||||
dropout_p=dropout_p,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
attn1 = attn1.transpose(1, 2)
|
||||
attn2 = attn2.transpose(1, 2)
|
||||
|
||||
@@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
||||
)
|
||||
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
base_config[attention] = True
|
||||
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
cfg = load_cfg(str(config_path))
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is False
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
|
||||
Reference in New Issue
Block a user