progress on modeling code

This commit is contained in:
Dan Saunders
2024-12-24 05:30:46 +00:00
parent 4ff3328e66
commit eb6611d55f
9 changed files with 216 additions and 316 deletions

View File

@@ -26,34 +26,27 @@ LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"): def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time""" """Run test inference and return generation time"""
try: inputs = tokenizer(prompt, return_tensors="pt")
inputs = tokenizer(prompt, return_tensors="pt") inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
inputs = {
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
}
start = time() start = time()
with torch.no_grad(): with torch.no_grad():
outputs = model.generate( outputs = model.generate(
**inputs, **inputs,
max_new_tokens=20, max_new_tokens=20,
num_beams=1, num_beams=1,
do_sample=False, do_sample=False,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
use_cache=False, use_cache=False,
) )
elapsed = time() - start elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt) LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text) LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed) LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text return elapsed, generated_text
except Exception as exc:
LOG.error("Inference failed: %s", str(exc))
raise
def convert_diff_transformer(cfg, cli_args, config_path): def convert_diff_transformer(cfg, cli_args, config_path):
@@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET + Fore.RESET
) )
try: try:
LlamaDifferentialForCausalLM.from_llama( model = LlamaDifferentialForCausalLM.from_llama(
model, model,
LlamaDifferentialConfig( LlamaDifferentialConfig(
**model.config.__dict__, **model.config.__dict__,
@@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
split_heads=cli_args.split_heads, split_heads=cli_args.split_heads,
), ),
) )
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc: except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise raise

View File

@@ -7,4 +7,7 @@ plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true diff_attention: true
diff_attn_zero_init: false
diff_attn_sublayer_norm: true
diff_attn_split_heads: false
``` ```

View File

@@ -8,18 +8,7 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin): class DifferentialTransformerPlugin(BasePlugin):
""" """Plugin for differential transformer integration with Axolotl."""
Plugin for differential transformer integration with Axolotl.
"""
def get_input_args(self): def get_input_args(self):
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.diff_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()

View File

@@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer.""" """Input args for differential transformer."""
diff_attention: Optional[bool] = None diff_attention: Optional[bool] = None
diff_attn_zero_init: Optional[bool] = None
diff_attn_sublayer_norm: Optional[bool] = None
diff_attn_split_heads: Optional[bool] = None

View File

@@ -1,135 +0,0 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maps each stock Llama attention class to the differential attention class
# that replaces it. convert_to_diff_attn matches entries by class *name*
# (type(child).__name__), not by identity.
ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}
def _copy_into_widened_projection(
    new_weight: torch.Tensor,
    old_weight: torch.Tensor,
    zero_init: bool,
) -> None:
    """Fill ``new_weight`` with ``old_weight`` followed by an initialised tail.

    The first ``old_weight.size(0)`` rows receive the original weights (the
    "positive" component). Any remaining rows (the second component) are set
    to zero when ``zero_init`` is true, otherwise drawn from N(0, 0.1).
    """
    # Size the split from the source weight itself rather than from
    # ``hidden_size`` so the copy stays valid whenever the projection's output
    # width is not equal to the hidden size.
    old_rows = old_weight.size(0)

    staged = torch.empty_like(new_weight)
    staged[:old_rows] = old_weight
    if zero_init:
        staged[old_rows:] = 0
    else:
        nn.init.normal_(staged[old_rows:], mean=0, std=0.1)
    new_weight.copy_(staged)


def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
    new_attn: Union[
        LlamaDifferentialAttention,
        LlamaDifferentialSdpaAttention,
        LlamaDifferentialFlashAttention2,
    ],
    zero_init: bool = False,
) -> None:
    """
    Copy weights from an attention layer into a differential attention layer.

    The old Q and K projection weights are written to the leading rows of the
    new Q and K projections (Q1 / K1). The remaining rows (Q2 / K2) are zeroed
    when ``zero_init`` is true and otherwise initialised from N(0, 0.1). The V
    and output projection weights are copied unchanged.

    Args:
        old_attn: Source attention layer; its ``q_proj``, ``k_proj``,
            ``v_proj`` and ``o_proj`` weights are read.
        new_attn: Destination differential attention layer, modified in place.
        zero_init: If true, zero Q2 / K2 and every lambda parameter
            (``lambda_q1``, ``lambda_k1``, ``lambda_q2``, ``lambda_k2``,
            ``lambda_init``) instead of leaving them at their own
            initialisation.
    """
    # Q projection (Q1 followed by Q2), then K projection (K1 followed by K2).
    # The order is kept so random initialisation consumes the RNG in the same
    # sequence as before.
    _copy_into_widened_projection(
        new_attn.q_proj.weight.data, old_attn.q_proj.weight.data, zero_init
    )
    _copy_into_widened_projection(
        new_attn.k_proj.weight.data, old_attn.k_proj.weight.data, zero_init
    )

    # Single V projection and the output projection carry over as-is.
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    # Zero out the lambda parameters as well when a zero initialisation is
    # requested.
    # NOTE(review): assumes every lambda_* attribute, including lambda_init,
    # is a tensor that nn.init.zeros_ can modify in place - confirm.
    if zero_init:
        for lambda_name in (
            "lambda_q1",
            "lambda_k1",
            "lambda_q2",
            "lambda_k2",
            "lambda_init",
        ):
            nn.init.zeros_(getattr(new_attn, lambda_name))

    logger.debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )
def convert_to_diff_attn(
    model: PreTrainedModel,
    zero_init: bool = False,
    sublayer_norm: bool = True,
    split_heads: bool = True,
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention.

    Walks the module tree and replaces every child whose class name matches a
    key of ``ATTENTION_MAPPING`` with a new instance of the mapped
    differential attention class. The model is modified in place and returned.

    Args:
        model: Model whose attention layers are replaced.
        zero_init: Forwarded to ``copy_attention_weights``.
        sublayer_norm: Stored on ``model.config.sublayer_norm``.
        split_heads: Stored on ``model.config.split_heads``. When true, the
            old layer's weights are NOT copied into the replacement.

    Returns:
        The same ``model`` object, with its attention layers replaced.
    """
    # Record the options on the model config; the new attention layers are
    # constructed from a config object below.
    model.config.sublayer_norm = sublayer_norm
    model.config.split_heads = split_heads

    # Resolve replacements by class name once, instead of rebuilding the name
    # list and rescanning the mapping for every visited child.
    diff_class_by_name = {
        orig_class.__name__: diff_class
        for orig_class, diff_class in ATTENTION_MAPPING.items()
    }
    layer_idx = 0

    def convert_module(module):
        nonlocal layer_idx

        # Iterate through module children, convert any attn layers to diff attn
        for name, child in module.named_children():
            child_class_name = type(child).__name__
            attention_class = diff_class_by_name.get(child_class_name)

            if attention_class is not None:
                logger.info(
                    "Converting attention layer %s: %s to %s",
                    layer_idx,
                    child_class_name,
                    attention_class.__name__,
                )

                # Prefer the parent module's own config when it has one.
                new_attention = attention_class(
                    config=getattr(module, "config", model.config),
                    layer_idx=layer_idx,
                )

                # Match both device and dtype of the layer being replaced, so
                # the new layer does not stay in the default dtype inside a
                # half-precision model.
                reference_weight = child.q_proj.weight
                new_attention.to(
                    device=reference_weight.device, dtype=reference_weight.dtype
                )

                # NOTE(review): in split-heads mode no weights are copied, so
                # the replacement keeps its own initialisation - confirm that
                # this is intended.
                if not split_heads:
                    copy_attention_weights(child, new_attention, zero_init=zero_init)

                # Replace the layer
                setattr(module, name, new_attention)
                layer_idx += 1
            elif next(child.children(), None) is not None:
                # Only recurse into modules that have children of their own.
                convert_module(child)

    convert_module(model)
    logger.info(
        "Converted %s attention layers to differential attention", layer_idx
    )

    return model

View File

@@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
"""Initialize configuration parameters.""" """Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // config.num_attention_heads
self.base_num_heads = config.num_attention_heads self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads self.base_num_kv_heads = config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
@@ -66,15 +68,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
if config.split_heads: if config.split_heads:
# Split heads mode - single projections # Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even # implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2 self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim self.value_head_dim = 2 * self.head_dim
else: else:
# Double projection mode # Double projection mode
self.head_dim = config.hidden_size // config.num_attention_heads
self.heads_per_component = self.base_num_heads self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim self.value_head_dim = self.head_dim
def _init_projections(self): def _init_projections(self):
@@ -90,14 +92,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2 self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
) )
self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False) self.q_proj = nn.Linear(
self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False) self.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear( self.v_proj = nn.Linear(
self.hidden_size, self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads, self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False, bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.hidden_size,
bias=self.config.attention_bias,
) )
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def _init_differential_params(self): def _init_differential_params(self):
"""Initialize differential attention parameters.""" """Initialize differential attention parameters."""
@@ -145,13 +155,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2 1, 2
) )
k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2 1, 2
) )
k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2 1, 2
) )
v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose( v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
1, 2 1, 2
) )
@@ -184,10 +194,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1) k1, k2 = k.unbind(dim=1)
# Repeat KV heads # Repeat KV heads to match number of query heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) v = repeat_kv(v, self.num_key_value_groups)
return k1, k2, v return k1, k2, v

View File

@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
# Replace standard attention with differential attention in each layer # Replace standard attention with differential attention in each layer
for layer in self.layers: for idx, layer in enumerate(self.layers):
attn_impl = config._attn_implementation or "eager" attn_impl = config._attn_implementation or "eager"
if attn_impl == "eager": if attn_impl == "eager":
layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx) layer.self_attn = LlamaDifferentialAttention(config, idx)
elif attn_impl == "sdpa": elif attn_impl == "sdpa":
layer.self_attn = LlamaDifferentialSdpaAttention( layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
config, layer.layer_idx
)
elif attn_impl == "flash_attention_2": elif attn_impl == "flash_attention_2":
layer.self_attn = LlamaDifferentialFlashAttention2( layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
config, layer.layer_idx
)
@classmethod @classmethod
def from_llama( def from_llama(
@@ -78,7 +75,21 @@ class LlamaDifferentialModel(LlamaModel):
if config is None: if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__) config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config) new_model = cls(config)
# Copy all weights except attention # Copy all weights except attention
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict()) new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict()) new_model.norm.load_state_dict(model.norm.state_dict())
@@ -97,34 +108,28 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.v_proj.load_state_dict( new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict() old_layer.self_attn.v_proj.state_dict()
) )
print(old_layer.self_attn.o_proj.weight.shape)
new_layer.self_attn.o_proj.load_state_dict( new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict() old_layer.self_attn.o_proj.state_dict()
) )
if config.split_heads: # Get the original projection sizes
new_layer.self_attn.q_proj.weight.data.copy_( old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data old_layer.self_attn.q_proj.weight.data
) )
new_layer.self_attn.k_proj.weight.data.copy_( new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
old_layer.self_attn.k_proj.weight.data
)
else:
new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_(
old_layer.self_attn.k_proj.weight.data old_layer.self_attn.k_proj.weight.data
) )
if config.zero_init: if config.zero_init:
# Zero out components as needed # Zero out components as needed
with torch.no_grad(): with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[ new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
config.hidden_size : new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
].zero_()
new_layer.self_attn.k_proj.weight.data[
config.hidden_size :
].zero_()
new_layer.self_attn.lambda_q1.zero_() new_layer.self_attn.lambda_q1.zero_()
new_layer.self_attn.lambda_k1.zero_() new_layer.self_attn.lambda_k1.zero_()
new_layer.self_attn.lambda_q2.zero_() new_layer.self_attn.lambda_q2.zero_()
@@ -149,7 +154,21 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
if config is None: if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__) config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config) new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config) new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict()) new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model return new_model

View File

@@ -710,11 +710,30 @@ class ModelLoader:
""" """
sample packing uses custom FA2 patch sample packing uses custom FA2 patch
""" """
# if self.cfg.flash_attention:
# if not self.cfg.sample_packing and self.cfg.s2_attention:
# pass
# self.model_kwargs["attn_implementation"] = "flash_attention_2"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "flash_attention_2"
# )
# elif self.cfg.sdp_attention:
# self.model_kwargs["attn_implementation"] = "sdpa"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "sdpa"
# )
# elif self.cfg.eager_attention:
# self.model_kwargs["attn_implementation"] = "eager"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "eager"
# )
if self.cfg.flash_attention: if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
if self.cfg.differentiaion: if self.cfg.diff_attention:
self.model_kwargs[ self.model_kwargs[
"attn_implementation" "attn_implementation"
] = "differential_flash_attention_2" ] = "differential_flash_attention_2"

View File

@@ -15,135 +15,133 @@ from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs from axolotl.common.cli import ConvertDiffTransformerCliArgs
@pytest.mark.usefixtures("base_config", "cli_runner") def test_cli_validation(cli_runner):
class TestDiffTransformer: # Test missing config file
"""Tests for convert-diff-transformer CLI command""" result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
def test_cli_validation(self, cli_runner): # Test non-existent config file
# Test missing config file result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
result = cli_runner.invoke(cli, ["convert-diff-transformer"]) assert result.exit_code != 0
assert result.exit_code != 0 assert "Error: Invalid value for 'CONFIG'" in result.output
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(self, cli_runner, tmp_path: Path, base_config): def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml" config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file: with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file) yaml.dump(base_config, file)
with patch( with patch(
"axolotl.cli.integrations.convert_diff_transformer.do_cli" "axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli: ) as mock_do_cli:
result = cli_runner.invoke( result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
cli, ["convert-diff-transformer", str(config_path)] assert result.exit_code == 0
)
assert result.exit_code == 0
mock_do_cli.assert_called_once() mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path) assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(self, tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml" def test_conversion_cli_basic(tmp_path: Path, base_config):
with open(config_path, "w", encoding="utf-8") as file: output_dir = tmp_path / "converted"
yaml.dump(base_config, file) base_config["output_dir"] = str(output_dir)
cfg = load_cfg(str(config_path)) config_path = tmp_path / "config.yml"
cli_args = ConvertDiffTransformerCliArgs() with open(config_path, "w", encoding="utf-8") as file:
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) yaml.dump(base_config, file)
assert not debug_info cfg = load_cfg(str(config_path))
assert (output_dir / "model.safetensors").exists() cli_args = ConvertDiffTransformerCliArgs()
assert (output_dir / "config.json").exists() _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_debug(self, tmp_path: Path, base_config): assert not debug_info
output_dir = tmp_path / "converted" assert (output_dir / "model.safetensors").exists()
base_config["output_dir"] = str(output_dir) assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path)) def test_conversion_cli_debug(tmp_path: Path, base_config):
cli_args = ConvertDiffTransformerCliArgs(debug=True) output_dir = tmp_path / "converted"
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) base_config["output_dir"] = str(output_dir)
assert not debug_info["generations_match"] config_path = tmp_path / "config.yml"
assert not debug_info["match_expected"] with open(config_path, "w", encoding="utf-8") as file:
assert (output_dir / "model.safetensors").exists() yaml.dump(base_config, file)
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_reproduce(self, tmp_path: Path, base_config): cfg = load_cfg(str(config_path))
output_dir = tmp_path / "converted" cli_args = ConvertDiffTransformerCliArgs(debug=True)
base_config["output_dir"] = str(output_dir) _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
config_path = tmp_path / "config.yml" assert not debug_info["generations_match"]
with open(config_path, "w", encoding="utf-8") as file: assert not debug_info["match_expected"]
yaml.dump(base_config, file) assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True def test_conversion_cli_reproduce(tmp_path: Path, base_config):
assert (output_dir / "model.safetensors").exists() output_dir = tmp_path / "converted"
assert (output_dir / "config.json").exists() base_config["output_dir"] = str(output_dir)
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize( config_path = tmp_path / "config.yml"
"attention", ["eager_attention", "sdp_attention", "flash_attention"] with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
) )
def test_conversion_cli_repoduce_attentions( _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
self, tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml" assert debug_info["generations_match"] is True
with open(config_path, "w", encoding="utf-8") as file: assert (output_dir / "model.safetensors").exists()
yaml.dump(base_config, file) assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True @pytest.mark.parametrize(
assert (output_dir / "model.safetensors").exists() "attention", ["eager_attention", "sdp_attention", "flash_attention"]
assert (output_dir / "config.json").exists() )
assert (output_dir / "axolotl_config.yml").exists() def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
@pytest.mark.parametrize( config_path = tmp_path / "config.yml"
"attention", ["eager_attention", "sdp_attention", "flash_attention"] with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
) )
def test_conversion_cli_split_heads( _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
self, tmp_path: Path, base_config, attention: str
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml" assert debug_info["generations_match"] is True
with open(config_path, "w", encoding="utf-8") as file: assert (output_dir / "model.safetensors").exists()
yaml.dump(base_config, file) assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False @pytest.mark.parametrize(
assert (output_dir / "model.safetensors").exists() "attention", ["eager_attention", "sdp_attention", "flash_attention"]
assert (output_dir / "config.json").exists() )
assert (output_dir / "axolotl_config.yml").exists() def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()