progress on modeling code

Dan Saunders
2024-12-24 05:30:46 +00:00
parent 4ff3328e66
commit eb6611d55f
9 changed files with 216 additions and 316 deletions

View File

@@ -26,34 +26,27 @@ LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
try:
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
}
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
except Exception as exc:
LOG.error("Inference failed: %s", str(exc))
raise
return elapsed, generated_text
def convert_diff_transformer(cfg, cli_args, config_path):
@@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET
)
try:
model = LlamaDifferentialForCausalLM.from_llama(
model,
LlamaDifferentialConfig(
**model.config.__dict__,
@@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
split_heads=cli_args.split_heads,
),
)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise

View File

@@ -7,4 +7,7 @@ plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
diff_attn_zero_init: false
diff_attn_sublayer_norm: true
diff_attn_split_heads: false
```
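For context, these plugin flags mirror the keyword arguments that the conversion entry point above forwards into `LlamaDifferentialConfig`. A minimal sketch of driving the same conversion programmatically; the module path and checkpoint name are assumptions, not taken from this diff:

```python
# Minimal sketch, not repository code: import path and checkpoint are placeholders.
from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (  # assumed path
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder
model = LlamaDifferentialForCausalLM.from_llama(
    model,
    LlamaDifferentialConfig(
        **model.config.__dict__,
        zero_init=False,      # corresponds to diff_attn_zero_init
        sublayer_norm=True,   # corresponds to diff_attn_sublayer_norm
        split_heads=False,    # corresponds to diff_attn_split_heads
    ),
)
```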

View File

@@ -8,18 +8,7 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""
Plugin for differential transformer integration with Axolotl.
"""
"""Plugin for differential transformer integration with Axolotl."""
def get_input_args(self):
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.diff_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()
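The body of `patch_llama_attention_classes` is not part of this diff. One plausible shape for such a patch, assuming the `LLAMA_ATTENTION_CLASSES` registry exposed by the transformers version these files import `LlamaSdpaAttention` from, might look like the hypothetical sketch below; the `differential_eager` and `differential_sdpa` keys are guesses, only `differential_flash_attention_2` appears elsewhere in this commit:

```python
# Hypothetical sketch only -- the real patch_llama_attention_classes is not shown in this diff.
from transformers.models.llama import modeling_llama

from axolotl.integrations.diff_transformer.diff_attn import (  # assumed module path
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)


def patch_llama_attention_classes_sketch() -> None:
    """Register differential attention classes so an attn_implementation string can select them."""
    modeling_llama.LLAMA_ATTENTION_CLASSES.update(
        {
            "differential_eager": LlamaDifferentialAttention,
            "differential_sdpa": LlamaDifferentialSdpaAttention,
            "differential_flash_attention_2": LlamaDifferentialFlashAttention2,
        }
    )
```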

View File

@@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
diff_attention: Optional[bool] = None
diff_attn_zero_init: Optional[bool] = None
diff_attn_sublayer_norm: Optional[bool] = None
diff_attn_split_heads: Optional[bool] = None

View File

@@ -1,135 +0,0 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}
def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
new_attn: Union[
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
LlamaDifferentialFlashAttention2,
],
zero_init: bool = False,
) -> None:
"""
Copy weights from old attention layer to new differential attention layer.
Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
to original attention mechanism.
"""
# For Q projection (Q1 and Q2)
new_q = torch.empty_like(new_attn.q_proj.weight.data)
new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1
if zero_init:
new_q[new_attn.hidden_size :] = 0
else:
nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
new_attn.q_proj.weight.data.copy_(new_q)
# For K projection (K1 and K2)
old_kv_size = old_attn.k_proj.weight.data.size(0)
new_k = torch.empty_like(new_attn.k_proj.weight.data)
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
if zero_init:
new_k[old_kv_size:] = 0
else:
nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
new_attn.k_proj.weight.data.copy_(new_k)
# For V projection (single V)
new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
# Output projection remains the same
new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
# Zero out lambda parameters for exact equivalence
if zero_init:
nn.init.zeros_(new_attn.lambda_q1)
nn.init.zeros_(new_attn.lambda_k1)
nn.init.zeros_(new_attn.lambda_q2)
nn.init.zeros_(new_attn.lambda_k2)
nn.init.zeros_(new_attn.lambda_init)
logger.debug(
"Copied positive attention weights from %s to %s",
type(old_attn).__name__,
type(new_attn).__name__,
)
def convert_to_diff_attn(
model: PreTrainedModel,
zero_init: bool = False,
sublayer_norm: bool = True,
split_heads: bool = True,
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
layer_idx = 0
# Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm
model.config.split_heads = split_heads
def convert_module(module):
nonlocal layer_idx
# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
child_class_name = type(child).__name__
if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
# Find matching attention class by name
for orig_class, diff_class in ATTENTION_MAPPING.items():
if orig_class.__name__ == child_class_name:
attention_class = diff_class
break
layer_type = type(child).__name__
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
)
# Create new diff attn layer
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,
layer_idx=layer_idx,
)
# Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device)
if not split_heads:
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer
setattr(module, name, new_attention)
layer_idx += 1
elif len(list(child.children())) > 0:
convert_module(child)
convert_module(model)
logger.info(f"Converted {layer_idx} attention layers to differential attention")
return model
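The exact-equivalence claim in `copy_attention_weights` can be checked with a toy calculation. In the differential-attention formulation, the learned scalar is λ = exp(λ_q1·λ_k1) − exp(λ_q2·λ_k2) + λ_init, so zeroing every λ parameter gives λ = 1 − 1 + 0 = 0, and with Q2/K2 zeroed the subtracted attention map contributes nothing. A minimal numeric sketch, not repository code (single head, no RoPE, no sublayer norm):

```python
# Toy equivalence check under the zero_init assumptions described above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, head_dim = 4, 8
q1 = torch.randn(seq_len, head_dim)
k1 = torch.randn(seq_len, head_dim)
v = torch.randn(seq_len, head_dim)
q2 = torch.zeros(seq_len, head_dim)  # zero-initialized negative query projection
k2 = torch.zeros(seq_len, head_dim)  # zero-initialized negative key projection

# lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init; all parameters zeroed -> 0
lam = torch.exp(torch.tensor(0.0)) - torch.exp(torch.tensor(0.0)) + 0.0

scale = head_dim**-0.5
standard = F.softmax(q1 @ k1.T * scale, dim=-1) @ v
differential = (
    F.softmax(q1 @ k1.T * scale, dim=-1) - lam * F.softmax(q2 @ k2.T * scale, dim=-1)
) @ v

assert torch.allclose(standard, differential)
```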

View File

@@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
"""Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // config.num_attention_heads
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
@@ -66,15 +68,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
if config.split_heads:
# Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
# Double projection mode
self.head_dim = config.hidden_size // config.num_attention_heads
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self):
@@ -90,14 +92,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
)
self.q_proj = nn.Linear(
    self.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
    self.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
    self.hidden_size,
    self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
    bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
    self.base_num_heads * self.head_dim,
    self.hidden_size,
    bias=self.config.attention_bias,
)
def _init_differential_params(self):
"""Initialize differential attention parameters."""
@@ -145,13 +155,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
    1, 2
)
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
    1, 2
)
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
    1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
    1, 2
)
@@ -184,10 +194,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads to match number of query heads
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
return k1, k2, v
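The `num_key_value_groups` repetition above follows the usual grouped-query-attention pattern: each KV head is tiled so the key/value tensors line up with the larger number of query heads. A minimal sketch of that shape transformation, mirroring the behavior of transformers' `repeat_kv` helper (the tensor sizes are illustrative assumptions):

```python
# Illustrative sketch of grouped-query KV repetition; mirrors transformers' repeat_kv.
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile each KV head n_rep times: (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


k1 = torch.randn(2, 4, 16, 64)              # 4 KV heads (illustrative sizes)
print(repeat_kv_sketch(k1, n_rep=8).shape)  # 32 query heads // 4 KV heads -> torch.Size([2, 32, 16, 64])
```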

View File

@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
def __init__(self, config):
super().__init__(config)
# Replace standard attention with differential attention in each layer
for idx, layer in enumerate(self.layers):
    attn_impl = config._attn_implementation or "eager"
    if attn_impl == "eager":
        layer.self_attn = LlamaDifferentialAttention(config, idx)
    elif attn_impl == "sdpa":
        layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
    elif attn_impl == "flash_attention_2":
        layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
@classmethod
def from_llama(
@@ -78,7 +75,21 @@ class LlamaDifferentialModel(LlamaModel):
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
# Copy all weights except attention
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
@@ -97,34 +108,28 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
    old_layer.self_attn.o_proj.state_dict()
)
# Get the original projection sizes
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
    new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
        old_layer.self_attn.q_proj.weight.data
    )
    new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
        old_layer.self_attn.k_proj.weight.data
    )
if config.zero_init:
    # Zero out components as needed
    with torch.no_grad():
        new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
        new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
        new_layer.self_attn.lambda_q1.zero_()
        new_layer.self_attn.lambda_k1.zero_()
        new_layer.self_attn.lambda_q2.zero_()
@@ -149,7 +154,21 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model

View File

@@ -710,11 +710,30 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
# if self.cfg.flash_attention:
# if not self.cfg.sample_packing and self.cfg.s2_attention:
# pass
# self.model_kwargs["attn_implementation"] = "flash_attention_2"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "flash_attention_2"
# )
# elif self.cfg.sdp_attention:
# self.model_kwargs["attn_implementation"] = "sdpa"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "sdpa"
# )
# elif self.cfg.eager_attention:
# self.model_kwargs["attn_implementation"] = "eager"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "eager"
# )
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"

View File

@@ -15,135 +15,133 @@ from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
@pytest.mark.usefixtures("base_config", "cli_runner")
class TestDiffTransformer:
"""Tests for convert-diff-transformer CLI command"""
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
def test_cli_validation(self, cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
# Test non-existent config file
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-diff-transformer", str(config_path)]
)
assert result.exit_code == 0
with patch(
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(self, tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs()
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
assert not debug_info
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs()
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
def test_conversion_cli_debug(self, tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
assert not debug_info
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
def test_conversion_cli_debug(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
assert not debug_info["generations_match"]
assert not debug_info["match_expected"]
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
def test_conversion_cli_reproduce(self, tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
assert not debug_info["generations_match"]
assert not debug_info["match_expected"]
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
def test_conversion_cli_repoduce_attentions(
self, tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
def test_conversion_cli_split_heads(
self, tmp_path: Path, base_config, attention: str
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()