adding split_heads argument for retaining original (Q, K) dimensionality
This commit is contained in:
@@ -14,9 +14,7 @@ from transformers import HfArgumentParser
|
|||||||
|
|
||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.integrations.differential_transformer.convert import (
|
from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
|
||||||
convert_to_differential_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):
|
|||||||
|
|
||||||
# Convert attention
|
# Convert attention
|
||||||
LOG.info("Converting to differential attention...")
|
LOG.info("Converting to differential attention...")
|
||||||
|
if cli_args.split_heads and cli_args.zero_init:
|
||||||
|
LOG.warning(
|
||||||
|
Fore.YELLOW
|
||||||
|
+ "Warning: Using split_heads with zero_init is not recommended; "
|
||||||
|
+ "split_heads will preclude the effects of zero_init"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
model = convert_to_differential_attention(
|
model = convert_to_diff_attn(
|
||||||
model=model,
|
model=model,
|
||||||
zero_init=cli_args.zero_init,
|
zero_init=cli_args.zero_init,
|
||||||
sublayer_norm=cli_args.sublayer_norm,
|
sublayer_norm=cli_args.sublayer_norm,
|
||||||
|
split_heads=cli_args.split_heads,
|
||||||
)
|
)
|
||||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
|
|||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
zero_init: bool = field(default=False)
|
zero_init: bool = field(default=False)
|
||||||
sublayer_norm: bool = field(default=True)
|
sublayer_norm: bool = field(default=True)
|
||||||
|
split_heads: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
|
|||||||
@@ -80,14 +80,18 @@ def copy_attention_weights(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_differential_attention(
|
def convert_to_diff_attn(
|
||||||
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
|
model: PreTrainedModel,
|
||||||
|
zero_init: bool = False,
|
||||||
|
sublayer_norm: bool = True,
|
||||||
|
split_heads: bool = True,
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
"""Convert a pre-trained model's attention layers to differential attention"""
|
"""Convert a pre-trained model's attention layers to differential attention"""
|
||||||
layer_idx = 0
|
layer_idx = 0
|
||||||
|
|
||||||
# Set sublayer norm as config on the model.
|
# Set sublayer norm as config on the model.
|
||||||
model.config.sublayer_norm = sublayer_norm
|
model.config.sublayer_norm = sublayer_norm
|
||||||
|
model.config.split_heads = split_heads
|
||||||
|
|
||||||
def convert_module(module):
|
def convert_module(module):
|
||||||
nonlocal layer_idx
|
nonlocal layer_idx
|
||||||
@@ -111,7 +115,8 @@ def convert_to_differential_attention(
|
|||||||
|
|
||||||
# Copy weights from old attention to new attention
|
# Copy weights from old attention to new attention
|
||||||
new_attention.to(child.q_proj.weight.device)
|
new_attention.to(child.q_proj.weight.device)
|
||||||
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
if not split_heads:
|
||||||
|
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
||||||
|
|
||||||
# Replace the layer
|
# Replace the layer
|
||||||
setattr(module, name, new_attention)
|
setattr(module, name, new_attention)
|
||||||
|
|||||||
@@ -70,26 +70,51 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.base_num_heads = config.num_attention_heads
|
self.base_num_heads = config.num_attention_heads
|
||||||
self.base_num_kv_heads = config.num_key_value_heads
|
self.base_num_kv_heads = config.num_key_value_heads
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
|
||||||
|
if config.split_heads:
|
||||||
|
self.head_dim = config.hidden_size // config.num_attention_heads // 2
|
||||||
|
else:
|
||||||
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
|
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
self.split_heads = config.split_heads
|
||||||
|
|
||||||
# For Q1 and Q2
|
if config.split_heads:
|
||||||
self.q_proj = nn.Linear(
|
# Split heads mode
|
||||||
self.hidden_size,
|
assert (
|
||||||
self.hidden_size * 2,
|
self.base_num_heads % 2 == 0
|
||||||
bias=False,
|
), "Number of heads must be even for splitting"
|
||||||
)
|
self.heads_per_component = self.base_num_heads // 2
|
||||||
|
|
||||||
# For K1 and K2
|
# Single projections
|
||||||
self.k_proj = nn.Linear(
|
self.q_proj = nn.Linear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
|
self.hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Double projection mode
|
||||||
|
self.heads_per_component = self.base_num_heads
|
||||||
|
|
||||||
|
# Double-sized projections
|
||||||
|
self.q_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size * 2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Single V projection
|
# Single V projection
|
||||||
self.v_proj = nn.Linear(
|
self.v_proj = nn.Linear(
|
||||||
@@ -125,8 +150,14 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
|
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||||
sublayer_norm = getattr(config, "sublayer_norm", True)
|
sublayer_norm = getattr(config, "sublayer_norm", True)
|
||||||
|
|
||||||
|
if self.split_heads:
|
||||||
|
subln_dim = 2 * self.head_dim
|
||||||
|
else:
|
||||||
|
subln_dim = self.head_dim
|
||||||
|
|
||||||
self.subln = (
|
self.subln = (
|
||||||
LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
|
LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
|
||||||
if sublayer_norm
|
if sublayer_norm
|
||||||
else nn.Identity()
|
else nn.Identity()
|
||||||
)
|
)
|
||||||
@@ -167,7 +198,10 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Reshape V
|
# Reshape V
|
||||||
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
if self.split_heads:
|
||||||
|
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||||
|
else:
|
||||||
|
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Apply rotary embeddings
|
# Apply rotary embeddings
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
@@ -177,6 +211,10 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
|
if self.split_heads:
|
||||||
|
cos, _ = cos.chunk(2, dim=2)
|
||||||
|
sin, _ = sin.chunk(2, dim=2)
|
||||||
|
|
||||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||||
|
|
||||||
@@ -192,8 +230,6 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
|
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
|
||||||
|
|
||||||
# Calculate attention scores for both parts
|
# Calculate attention scores for both parts
|
||||||
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
|
|
||||||
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
|
|
||||||
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
||||||
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
@@ -307,13 +343,18 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
|
|||||||
k1, k2 = kp.chunk(2, dim=-1)
|
k1, k2 = kp.chunk(2, dim=-1)
|
||||||
|
|
||||||
# Reshape Q1,Q2 for attention
|
# Reshape Q1,Q2 for attention
|
||||||
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Reshape K1,K2 for attention
|
# Reshape K1,K2 for attention
|
||||||
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Reshape V
|
# Reshape V
|
||||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
if self.split_heads:
|
||||||
|
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||||
|
else:
|
||||||
|
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Apply rotary embeddings
|
# Apply rotary embeddings
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
@@ -323,6 +364,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
|
|||||||
else:
|
else:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
|
if self.split_heads:
|
||||||
|
cos, _ = cos.chunk(2, dim=2)
|
||||||
|
sin, _ = sin.chunk(2, dim=2)
|
||||||
|
|
||||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||||
|
|
||||||
@@ -468,13 +513,18 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
|||||||
k1, k2 = kp.chunk(2, dim=-1)
|
k1, k2 = kp.chunk(2, dim=-1)
|
||||||
|
|
||||||
# Reshape Q1,Q2 for attention
|
# Reshape Q1,Q2 for attention
|
||||||
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
|
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Reshape K1,K2 for attention
|
# Reshape K1,K2 for attention
|
||||||
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Reshape V
|
# Reshape V
|
||||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
if self.split_heads:
|
||||||
|
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
|
||||||
|
else:
|
||||||
|
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# Apply rotary embeddings
|
# Apply rotary embeddings
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
@@ -484,6 +534,10 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
|||||||
else:
|
else:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
|
if self.split_heads:
|
||||||
|
cos, _ = cos.chunk(2, dim=2)
|
||||||
|
sin, _ = sin.chunk(2, dim=2)
|
||||||
|
|
||||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||||
|
|
||||||
@@ -506,20 +560,54 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
|
|||||||
|
|
||||||
# Calculate attention using Flash Attention
|
# Calculate attention using Flash Attention
|
||||||
dropout_p = self.attention_dropout if self.training else 0.0
|
dropout_p = self.attention_dropout if self.training else 0.0
|
||||||
attn1 = flash_attn_func(
|
if self.split_heads:
|
||||||
q1,
|
v1, v2 = v.chunk(2, dim=-1)
|
||||||
k1,
|
attn11 = flash_attn_func(
|
||||||
v,
|
q1,
|
||||||
dropout_p=dropout_p,
|
k1,
|
||||||
causal=True,
|
v1,
|
||||||
)
|
dropout_p=dropout_p,
|
||||||
attn2 = flash_attn_func(
|
causal=True,
|
||||||
q2,
|
)
|
||||||
k2,
|
attn12 = flash_attn_func(
|
||||||
v,
|
q1,
|
||||||
dropout_p=dropout_p,
|
k1,
|
||||||
causal=True,
|
v2,
|
||||||
)
|
dropout_p=dropout_p,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn1 = torch.cat([attn11, attn12], dim=-1)
|
||||||
|
|
||||||
|
attn21 = flash_attn_func(
|
||||||
|
q2,
|
||||||
|
k2,
|
||||||
|
v1,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn22 = flash_attn_func(
|
||||||
|
q2,
|
||||||
|
k2,
|
||||||
|
v2,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn2 = torch.cat([attn21, attn22], dim=-1)
|
||||||
|
else:
|
||||||
|
attn1 = flash_attn_func(
|
||||||
|
q1,
|
||||||
|
k1,
|
||||||
|
v,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn2 = flash_attn_func(
|
||||||
|
q2,
|
||||||
|
k2,
|
||||||
|
v,
|
||||||
|
dropout_p=dropout_p,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
attn1 = attn1.transpose(1, 2)
|
attn1 = attn1.transpose(1, 2)
|
||||||
attn2 = attn2.transpose(1, 2)
|
attn2 = attn2.transpose(1, 2)
|
||||||
|
|||||||
@@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
|
|||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
assert (output_dir / "config.json").exists()
|
assert (output_dir / "config.json").exists()
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
||||||
|
)
|
||||||
|
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||||
|
output_dir = tmp_path / "converted"
|
||||||
|
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
|
||||||
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
base_config[attention] = True
|
||||||
|
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
||||||
|
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
|
assert debug_info["generations_match"] is False
|
||||||
|
assert (output_dir / "model.safetensors").exists()
|
||||||
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user