progress on modeling code
This commit is contained in:
@@ -26,34 +26,27 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||||
"""Run test inference and return generation time"""
|
"""Run test inference and return generation time"""
|
||||||
try:
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
inputs = tokenizer(prompt, return_tensors="pt")
|
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
|
||||||
inputs = {
|
|
||||||
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
start = time()
|
start = time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model.generate(
|
outputs = model.generate(
|
||||||
**inputs,
|
**inputs,
|
||||||
max_new_tokens=20,
|
max_new_tokens=20,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
)
|
)
|
||||||
elapsed = time() - start
|
elapsed = time() - start
|
||||||
|
|
||||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
LOG.info("Prompt: %s", prompt)
|
LOG.info("Prompt: %s", prompt)
|
||||||
LOG.info("Generated: %s", generated_text)
|
LOG.info("Generated: %s", generated_text)
|
||||||
LOG.info("Generation time: %.2fs", elapsed)
|
LOG.info("Generation time: %.2fs", elapsed)
|
||||||
|
|
||||||
return elapsed, generated_text
|
return elapsed, generated_text
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
LOG.error("Inference failed: %s", str(exc))
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||||
@@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
|||||||
+ Fore.RESET
|
+ Fore.RESET
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
LlamaDifferentialForCausalLM.from_llama(
|
model = LlamaDifferentialForCausalLM.from_llama(
|
||||||
model,
|
model,
|
||||||
LlamaDifferentialConfig(
|
LlamaDifferentialConfig(
|
||||||
**model.config.__dict__,
|
**model.config.__dict__,
|
||||||
@@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
|||||||
split_heads=cli_args.split_heads,
|
split_heads=cli_args.split_heads,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -7,4 +7,7 @@ plugins:
|
|||||||
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
|
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
|
||||||
|
|
||||||
diff_attention: true
|
diff_attention: true
|
||||||
|
diff_attn_zero_init: false
|
||||||
|
diff_attn_sublayer_norm: true
|
||||||
|
diff_attn_split_heads: false
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -8,18 +8,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DifferentialTransformerPlugin(BasePlugin):
|
class DifferentialTransformerPlugin(BasePlugin):
|
||||||
"""
|
"""Plugin for differential transformer integration with Axolotl."""
|
||||||
Plugin for differential transformer integration with Axolotl.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
|
||||||
"""Apply differential attention patch before model loading if enabled."""
|
|
||||||
if cfg.diff_attention:
|
|
||||||
from axolotl.monkeypatch.attention.differential import (
|
|
||||||
patch_llama_attention_classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_llama_attention_classes()
|
|
||||||
|
|||||||
@@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel):
|
|||||||
"""Input args for differential transformer."""
|
"""Input args for differential transformer."""
|
||||||
|
|
||||||
diff_attention: Optional[bool] = None
|
diff_attention: Optional[bool] = None
|
||||||
|
diff_attn_zero_init: Optional[bool] = None
|
||||||
|
diff_attn_sublayer_norm: Optional[bool] = None
|
||||||
|
diff_attn_split_heads: Optional[bool] = None
|
||||||
|
|||||||
@@ -1,135 +0,0 @@
|
|||||||
"""Differential attention conversion logic for a huggingface pre-trained model."""
|
|
||||||
import logging
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaAttention,
|
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .diff_attn import (
|
|
||||||
LlamaDifferentialAttention,
|
|
||||||
LlamaDifferentialFlashAttention2,
|
|
||||||
LlamaDifferentialSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ATTENTION_MAPPING = {
|
|
||||||
LlamaAttention: LlamaDifferentialAttention,
|
|
||||||
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
|
|
||||||
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def copy_attention_weights(
|
|
||||||
old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
|
|
||||||
new_attn: Union[
|
|
||||||
LlamaDifferentialAttention,
|
|
||||||
LlamaDifferentialSdpaAttention,
|
|
||||||
LlamaDifferentialFlashAttention2,
|
|
||||||
],
|
|
||||||
zero_init: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Copy weights from old attention layer to new differential attention layer.
|
|
||||||
Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
|
|
||||||
to original attention mechanism.
|
|
||||||
"""
|
|
||||||
# For Q projection (Q1 and Q2)
|
|
||||||
new_q = torch.empty_like(new_attn.q_proj.weight.data)
|
|
||||||
new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1
|
|
||||||
if zero_init:
|
|
||||||
new_q[new_attn.hidden_size :] = 0
|
|
||||||
else:
|
|
||||||
nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
|
|
||||||
new_attn.q_proj.weight.data.copy_(new_q)
|
|
||||||
|
|
||||||
# For K projection (K1 and K2)
|
|
||||||
old_kv_size = old_attn.k_proj.weight.data.size(0)
|
|
||||||
new_k = torch.empty_like(new_attn.k_proj.weight.data)
|
|
||||||
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
|
|
||||||
if zero_init:
|
|
||||||
new_k[old_kv_size:] = 0
|
|
||||||
else:
|
|
||||||
nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
|
|
||||||
new_attn.k_proj.weight.data.copy_(new_k)
|
|
||||||
|
|
||||||
# For V projection (single V)
|
|
||||||
new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
|
|
||||||
|
|
||||||
# Output projection remains the same
|
|
||||||
new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
|
|
||||||
|
|
||||||
# Zero out lambda parameters for exact equivalence
|
|
||||||
if zero_init:
|
|
||||||
nn.init.zeros_(new_attn.lambda_q1)
|
|
||||||
nn.init.zeros_(new_attn.lambda_k1)
|
|
||||||
nn.init.zeros_(new_attn.lambda_q2)
|
|
||||||
nn.init.zeros_(new_attn.lambda_k2)
|
|
||||||
nn.init.zeros_(new_attn.lambda_init)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Copied positive attention weights from %s to %s",
|
|
||||||
type(old_attn).__name__,
|
|
||||||
type(new_attn).__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_diff_attn(
|
|
||||||
model: PreTrainedModel,
|
|
||||||
zero_init: bool = False,
|
|
||||||
sublayer_norm: bool = True,
|
|
||||||
split_heads: bool = True,
|
|
||||||
) -> PreTrainedModel:
|
|
||||||
"""Convert a pre-trained model's attention layers to differential attention"""
|
|
||||||
layer_idx = 0
|
|
||||||
|
|
||||||
# Set sublayer norm as config on the model.
|
|
||||||
model.config.sublayer_norm = sublayer_norm
|
|
||||||
model.config.split_heads = split_heads
|
|
||||||
|
|
||||||
def convert_module(module):
|
|
||||||
nonlocal layer_idx
|
|
||||||
|
|
||||||
# Iterate through module children, convert any attn layers to diff attn
|
|
||||||
for name, child in module.named_children():
|
|
||||||
child_class_name = type(child).__name__
|
|
||||||
|
|
||||||
if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
|
|
||||||
# Find matching attention class by name
|
|
||||||
for orig_class, diff_class in ATTENTION_MAPPING.items():
|
|
||||||
if orig_class.__name__ == child_class_name:
|
|
||||||
attention_class = diff_class
|
|
||||||
break
|
|
||||||
|
|
||||||
layer_type = type(child).__name__
|
|
||||||
logger.info(
|
|
||||||
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create new diff attn layer
|
|
||||||
new_attention = attention_class(
|
|
||||||
config=module.config if hasattr(module, "config") else model.config,
|
|
||||||
layer_idx=layer_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copy weights from old attention to new attention
|
|
||||||
new_attention.to(child.q_proj.weight.device)
|
|
||||||
if not split_heads:
|
|
||||||
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
|
||||||
|
|
||||||
# Replace the layer
|
|
||||||
setattr(module, name, new_attention)
|
|
||||||
layer_idx += 1
|
|
||||||
elif len(list(child.children())) > 0:
|
|
||||||
convert_module(child)
|
|
||||||
|
|
||||||
convert_module(model)
|
|
||||||
logger.info(f"Converted {layer_idx} attention layers to differential attention")
|
|
||||||
|
|
||||||
return model
|
|
||||||
@@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
"""Initialize configuration parameters."""
|
"""Initialize configuration parameters."""
|
||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
self.base_num_heads = config.num_attention_heads
|
self.base_num_heads = config.num_attention_heads
|
||||||
self.base_num_kv_heads = config.num_key_value_heads
|
self.base_num_kv_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
@@ -66,15 +68,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
|
|
||||||
if config.split_heads:
|
if config.split_heads:
|
||||||
# Split heads mode - single projections
|
# Split heads mode - single projections
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
|
||||||
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
||||||
# implementation, which asserts `self.base_num_heads` is even
|
# implementation, which asserts `self.base_num_heads` is even
|
||||||
self.heads_per_component = self.base_num_heads // 2
|
self.heads_per_component = self.base_num_heads // 2
|
||||||
|
self.kv_heads_per_component = self.base_num_kv_heads // 2
|
||||||
self.value_head_dim = 2 * self.head_dim
|
self.value_head_dim = 2 * self.head_dim
|
||||||
else:
|
else:
|
||||||
# Double projection mode
|
# Double projection mode
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
|
||||||
self.heads_per_component = self.base_num_heads
|
self.heads_per_component = self.base_num_heads
|
||||||
|
self.kv_heads_per_component = self.base_num_kv_heads
|
||||||
self.value_head_dim = self.head_dim
|
self.value_head_dim = self.head_dim
|
||||||
|
|
||||||
def _init_projections(self):
|
def _init_projections(self):
|
||||||
@@ -90,14 +92,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
|
||||||
)
|
)
|
||||||
|
|
||||||
self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
|
self.q_proj = nn.Linear(
|
||||||
self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
|
self.hidden_size, q_out_dim, bias=self.config.attention_bias
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
self.hidden_size, k_out_dim, bias=self.config.attention_bias
|
||||||
|
)
|
||||||
self.v_proj = nn.Linear(
|
self.v_proj = nn.Linear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
||||||
bias=False,
|
bias=self.config.attention_bias,
|
||||||
|
)
|
||||||
|
self.o_proj = nn.Linear(
|
||||||
|
self.base_num_heads * self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=self.config.attention_bias,
|
||||||
)
|
)
|
||||||
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
|
||||||
|
|
||||||
def _init_differential_params(self):
|
def _init_differential_params(self):
|
||||||
"""Initialize differential attention parameters."""
|
"""Initialize differential attention parameters."""
|
||||||
@@ -145,13 +155,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
)
|
)
|
||||||
k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
)
|
)
|
||||||
k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
)
|
)
|
||||||
v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
|
v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -184,10 +194,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
||||||
k1, k2 = k.unbind(dim=1)
|
k1, k2 = k.unbind(dim=1)
|
||||||
|
|
||||||
# Repeat KV heads
|
# Repeat KV heads to match number of query heads
|
||||||
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
|
k1 = repeat_kv(k1, self.num_key_value_groups)
|
||||||
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
|
k2 = repeat_kv(k2, self.num_key_value_groups)
|
||||||
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
|
v = repeat_kv(v, self.num_key_value_groups)
|
||||||
|
|
||||||
return k1, k2, v
|
return k1, k2, v
|
||||||
|
|
||||||
|
|||||||
@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Replace standard attention with differential attention in each layer
|
# Replace standard attention with differential attention in each layer
|
||||||
for layer in self.layers:
|
for idx, layer in enumerate(self.layers):
|
||||||
attn_impl = config._attn_implementation or "eager"
|
attn_impl = config._attn_implementation or "eager"
|
||||||
if attn_impl == "eager":
|
if attn_impl == "eager":
|
||||||
layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx)
|
layer.self_attn = LlamaDifferentialAttention(config, idx)
|
||||||
elif attn_impl == "sdpa":
|
elif attn_impl == "sdpa":
|
||||||
layer.self_attn = LlamaDifferentialSdpaAttention(
|
layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
|
||||||
config, layer.layer_idx
|
|
||||||
)
|
|
||||||
elif attn_impl == "flash_attention_2":
|
elif attn_impl == "flash_attention_2":
|
||||||
layer.self_attn = LlamaDifferentialFlashAttention2(
|
layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
|
||||||
config, layer.layer_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_llama(
|
def from_llama(
|
||||||
@@ -78,7 +75,21 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
if config is None:
|
if config is None:
|
||||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||||
|
|
||||||
|
# Validate head counts if using split heads mode
|
||||||
|
if config.split_heads:
|
||||||
|
if config.num_attention_heads % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of attention heads ({config.num_attention_heads}) must be even "
|
||||||
|
"when using split_heads=True"
|
||||||
|
)
|
||||||
|
if config.num_key_value_heads % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
|
||||||
|
"when using split_heads=True"
|
||||||
|
)
|
||||||
|
|
||||||
new_model = cls(config)
|
new_model = cls(config)
|
||||||
|
|
||||||
# Copy all weights except attention
|
# Copy all weights except attention
|
||||||
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
|
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
|
||||||
new_model.norm.load_state_dict(model.norm.state_dict())
|
new_model.norm.load_state_dict(model.norm.state_dict())
|
||||||
@@ -97,34 +108,28 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_layer.self_attn.v_proj.load_state_dict(
|
new_layer.self_attn.v_proj.load_state_dict(
|
||||||
old_layer.self_attn.v_proj.state_dict()
|
old_layer.self_attn.v_proj.state_dict()
|
||||||
)
|
)
|
||||||
|
print(old_layer.self_attn.o_proj.weight.shape)
|
||||||
new_layer.self_attn.o_proj.load_state_dict(
|
new_layer.self_attn.o_proj.load_state_dict(
|
||||||
old_layer.self_attn.o_proj.state_dict()
|
old_layer.self_attn.o_proj.state_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.split_heads:
|
# Get the original projection sizes
|
||||||
new_layer.self_attn.q_proj.weight.data.copy_(
|
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
|
||||||
|
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
|
||||||
|
|
||||||
|
if not config.split_heads:
|
||||||
|
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
|
||||||
old_layer.self_attn.q_proj.weight.data
|
old_layer.self_attn.q_proj.weight.data
|
||||||
)
|
)
|
||||||
new_layer.self_attn.k_proj.weight.data.copy_(
|
new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
|
||||||
old_layer.self_attn.k_proj.weight.data
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_(
|
|
||||||
old_layer.self_attn.q_proj.weight.data
|
|
||||||
)
|
|
||||||
new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_(
|
|
||||||
old_layer.self_attn.k_proj.weight.data
|
old_layer.self_attn.k_proj.weight.data
|
||||||
)
|
)
|
||||||
|
|
||||||
if config.zero_init:
|
if config.zero_init:
|
||||||
# Zero out components as needed
|
# Zero out components as needed
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
new_layer.self_attn.q_proj.weight.data[
|
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
|
||||||
config.hidden_size :
|
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
|
||||||
].zero_()
|
|
||||||
new_layer.self_attn.k_proj.weight.data[
|
|
||||||
config.hidden_size :
|
|
||||||
].zero_()
|
|
||||||
new_layer.self_attn.lambda_q1.zero_()
|
new_layer.self_attn.lambda_q1.zero_()
|
||||||
new_layer.self_attn.lambda_k1.zero_()
|
new_layer.self_attn.lambda_k1.zero_()
|
||||||
new_layer.self_attn.lambda_q2.zero_()
|
new_layer.self_attn.lambda_q2.zero_()
|
||||||
@@ -149,7 +154,21 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
|||||||
if config is None:
|
if config is None:
|
||||||
config = LlamaDifferentialConfig(**model.config.__dict__)
|
config = LlamaDifferentialConfig(**model.config.__dict__)
|
||||||
|
|
||||||
|
# Validate head counts if using split heads mode
|
||||||
|
if config.split_heads:
|
||||||
|
if config.num_attention_heads % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of attention heads ({config.num_attention_heads}) must be even "
|
||||||
|
"when using split_heads=True"
|
||||||
|
)
|
||||||
|
if config.num_key_value_heads % 2 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
|
||||||
|
"when using split_heads=True"
|
||||||
|
)
|
||||||
|
|
||||||
new_model = cls(config)
|
new_model = cls(config)
|
||||||
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
|
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
|
||||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|||||||
@@ -710,11 +710,30 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
|
# if self.cfg.flash_attention:
|
||||||
|
# if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
# "flash_attention_2"
|
||||||
|
# )
|
||||||
|
# elif self.cfg.sdp_attention:
|
||||||
|
# self.model_kwargs["attn_implementation"] = "sdpa"
|
||||||
|
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
# "sdpa"
|
||||||
|
# )
|
||||||
|
# elif self.cfg.eager_attention:
|
||||||
|
# self.model_kwargs["attn_implementation"] = "eager"
|
||||||
|
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
# "eager"
|
||||||
|
# )
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if self.cfg.differentiaion:
|
if self.cfg.diff_attention:
|
||||||
self.model_kwargs[
|
self.model_kwargs[
|
||||||
"attn_implementation"
|
"attn_implementation"
|
||||||
] = "differential_flash_attention_2"
|
] = "differential_flash_attention_2"
|
||||||
|
|||||||
@@ -15,135 +15,133 @@ from axolotl.cli.main import cli
|
|||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("base_config", "cli_runner")
|
def test_cli_validation(cli_runner):
|
||||||
class TestDiffTransformer:
|
# Test missing config file
|
||||||
"""Tests for convert-diff-transformer CLI command"""
|
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||||
|
|
||||||
def test_cli_validation(self, cli_runner):
|
# Test non-existent config file
|
||||||
# Test missing config file
|
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
|
||||||
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
|
assert result.exit_code != 0
|
||||||
assert result.exit_code != 0
|
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
|
||||||
|
|
||||||
# Test non-existent config file
|
|
||||||
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
|
|
||||||
assert result.exit_code != 0
|
|
||||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
|
||||||
|
|
||||||
def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
|
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
|
||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
yaml.dump(base_config, file)
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
|
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
|
||||||
) as mock_do_cli:
|
) as mock_do_cli:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
|
||||||
cli, ["convert-diff-transformer", str(config_path)]
|
assert result.exit_code == 0
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
mock_do_cli.assert_called_once()
|
mock_do_cli.assert_called_once()
|
||||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||||
|
|
||||||
def test_conversion_cli_basic(self, tmp_path: Path, base_config):
|
|
||||||
output_dir = tmp_path / "converted"
|
|
||||||
base_config["output_dir"] = str(output_dir)
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
output_dir = tmp_path / "converted"
|
||||||
yaml.dump(base_config, file)
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
config_path = tmp_path / "config.yml"
|
||||||
cli_args = ConvertDiffTransformerCliArgs()
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
assert not debug_info
|
cfg = load_cfg(str(config_path))
|
||||||
assert (output_dir / "model.safetensors").exists()
|
cli_args = ConvertDiffTransformerCliArgs()
|
||||||
assert (output_dir / "config.json").exists()
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
|
||||||
|
|
||||||
def test_conversion_cli_debug(self, tmp_path: Path, base_config):
|
assert not debug_info
|
||||||
output_dir = tmp_path / "converted"
|
assert (output_dir / "model.safetensors").exists()
|
||||||
base_config["output_dir"] = str(output_dir)
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
|
||||||
yaml.dump(base_config, file)
|
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
def test_conversion_cli_debug(tmp_path: Path, base_config):
|
||||||
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
output_dir = tmp_path / "converted"
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
|
||||||
assert not debug_info["generations_match"]
|
config_path = tmp_path / "config.yml"
|
||||||
assert not debug_info["match_expected"]
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
assert (output_dir / "model.safetensors").exists()
|
yaml.dump(base_config, file)
|
||||||
assert (output_dir / "config.json").exists()
|
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
|
||||||
|
|
||||||
def test_conversion_cli_reproduce(self, tmp_path: Path, base_config):
|
cfg = load_cfg(str(config_path))
|
||||||
output_dir = tmp_path / "converted"
|
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||||
base_config["output_dir"] = str(output_dir)
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
assert not debug_info["generations_match"]
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
assert not debug_info["match_expected"]
|
||||||
yaml.dump(base_config, file)
|
assert (output_dir / "model.safetensors").exists()
|
||||||
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs(
|
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
|
||||||
)
|
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
|
||||||
assert (output_dir / "model.safetensors").exists()
|
output_dir = tmp_path / "converted"
|
||||||
assert (output_dir / "config.json").exists()
|
base_config["output_dir"] = str(output_dir)
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
config_path = tmp_path / "config.yml"
|
||||||
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
)
|
)
|
||||||
def test_conversion_cli_repoduce_attentions(
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
self, tmp_path: Path, base_config, attention: Optional[str]
|
|
||||||
):
|
|
||||||
output_dir = tmp_path / "converted"
|
|
||||||
base_config["output_dir"] = str(output_dir)
|
|
||||||
base_config[attention] = True
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
assert debug_info["generations_match"] is True
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
assert (output_dir / "model.safetensors").exists()
|
||||||
yaml.dump(base_config, file)
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs(
|
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
|
||||||
)
|
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
@pytest.mark.parametrize(
|
||||||
assert (output_dir / "model.safetensors").exists()
|
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
||||||
assert (output_dir / "config.json").exists()
|
)
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
def test_conversion_cli_repoduce_attentions(
|
||||||
|
tmp_path: Path, base_config, attention: Optional[str]
|
||||||
|
):
|
||||||
|
output_dir = tmp_path / "converted"
|
||||||
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
base_config[attention] = True
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
config_path = tmp_path / "config.yml"
|
||||||
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
)
|
)
|
||||||
def test_conversion_cli_split_heads(
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
self, tmp_path: Path, base_config, attention: str
|
|
||||||
):
|
|
||||||
output_dir = tmp_path / "converted"
|
|
||||||
base_config["output_dir"] = str(output_dir)
|
|
||||||
base_config[attention] = True
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
assert debug_info["generations_match"] is True
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
assert (output_dir / "model.safetensors").exists()
|
||||||
yaml.dump(base_config, file)
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
|
||||||
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert debug_info["generations_match"] is False
|
@pytest.mark.parametrize(
|
||||||
assert (output_dir / "model.safetensors").exists()
|
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
|
||||||
assert (output_dir / "config.json").exists()
|
)
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||||
|
output_dir = tmp_path / "converted"
|
||||||
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
base_config[attention] = True
|
||||||
|
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
|
||||||
|
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
|
assert debug_info["generations_match"] is False
|
||||||
|
assert (output_dir / "model.safetensors").exists()
|
||||||
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user