progress on modeling code
@@ -26,34 +26,27 @@ LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Run test inference and return generation time"""
    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {
            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
        }
        inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}

        start = time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        elapsed = time() - start

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        LOG.info("Prompt: %s", prompt)
        LOG.info("Generated: %s", generated_text)
        LOG.info("Generation time: %.2fs", elapsed)

        return elapsed, generated_text

    except Exception as exc:
        LOG.error("Inference failed: %s", str(exc))
        raise
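For orientation, a minimal sketch of how this helper might be called; the checkpoint path below is a placeholder for illustration, not something referenced by this commit:

```python
# Hypothetical usage of test_inference(); substitute any small causal LM checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/small-causal-lm"  # placeholder, not part of the commit
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # generate() needs a pad token id

elapsed, text = test_inference(model, tokenizer, prompt="The quick brown fox")
print(f"{elapsed:.2f}s: {text}")
```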
def convert_diff_transformer(cfg, cli_args, config_path):
@@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
        + Fore.RESET
    )
    try:
        LlamaDifferentialForCausalLM.from_llama(
        model = LlamaDifferentialForCausalLM.from_llama(
            model,
            LlamaDifferentialConfig(
                **model.config.__dict__,
@@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
                split_heads=cli_args.split_heads,
            ),
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise
@@ -7,4 +7,7 @@ plugins:
  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin

diff_attention: true
diff_attn_zero_init: false
diff_attn_sublayer_norm: true
diff_attn_split_heads: false
```
@@ -8,18 +8,7 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
    """
    Plugin for differential transformer integration with Axolotl.
    """
    """Plugin for differential transformer integration with Axolotl."""

    def get_input_args(self):
        return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"

    def pre_model_load(self, cfg):
        """Apply differential attention patch before model loading if enabled."""
        if cfg.diff_attention:
            from axolotl.monkeypatch.attention.differential import (
                patch_llama_attention_classes,
            )

            patch_llama_attention_classes()
@@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel):
    """Input args for differential transformer."""

    diff_attention: Optional[bool] = None
    diff_attn_zero_init: Optional[bool] = None
    diff_attn_sublayer_norm: Optional[bool] = None
    diff_attn_split_heads: Optional[bool] = None
@@ -1,135 +0,0 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

from .diff_attn import (
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)

logger = logging.getLogger(__name__)

ATTENTION_MAPPING = {
    LlamaAttention: LlamaDifferentialAttention,
    LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
    LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}


def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
    new_attn: Union[
        LlamaDifferentialAttention,
        LlamaDifferentialSdpaAttention,
        LlamaDifferentialFlashAttention2,
    ],
    zero_init: bool = False,
) -> None:
    """
    Copy weights from old attention layer to new differential attention layer.
    Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
    to original attention mechanism.
    """
    # For Q projection (Q1 and Q2)
    new_q = torch.empty_like(new_attn.q_proj.weight.data)
    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
    if zero_init:
        new_q[new_attn.hidden_size :] = 0
    else:
        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
    new_attn.q_proj.weight.data.copy_(new_q)

    # For K projection (K1 and K2)
    old_kv_size = old_attn.k_proj.weight.data.size(0)
    new_k = torch.empty_like(new_attn.k_proj.weight.data)
    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
    if zero_init:
        new_k[old_kv_size:] = 0
    else:
        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
    new_attn.k_proj.weight.data.copy_(new_k)

    # For V projection (single V)
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)

    # Output projection remains the same
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    # Zero out lambda parameters for exact equivalence
    if zero_init:
        nn.init.zeros_(new_attn.lambda_q1)
        nn.init.zeros_(new_attn.lambda_k1)
        nn.init.zeros_(new_attn.lambda_q2)
        nn.init.zeros_(new_attn.lambda_k2)
        nn.init.zeros_(new_attn.lambda_init)

    logger.debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )


def convert_to_diff_attn(
    model: PreTrainedModel,
    zero_init: bool = False,
    sublayer_norm: bool = True,
    split_heads: bool = True,
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    layer_idx = 0

    # Set sublayer norm as config on the model.
    model.config.sublayer_norm = sublayer_norm
    model.config.split_heads = split_heads

    def convert_module(module):
        nonlocal layer_idx

        # Iterate through module children, convert any attn layers to diff attn
        for name, child in module.named_children():
            child_class_name = type(child).__name__

            if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
                # Find matching attention class by name
                for orig_class, diff_class in ATTENTION_MAPPING.items():
                    if orig_class.__name__ == child_class_name:
                        attention_class = diff_class
                        break

                layer_type = type(child).__name__
                logger.info(
                    f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
                )

                # Create new diff attn layer
                new_attention = attention_class(
                    config=module.config if hasattr(module, "config") else model.config,
                    layer_idx=layer_idx,
                )

                # Copy weights from old attention to new attention
                new_attention.to(child.q_proj.weight.device)
                if not split_heads:
                    copy_attention_weights(child, new_attention, zero_init=zero_init)

                # Replace the layer
                setattr(module, name, new_attention)
                layer_idx += 1
            elif len(list(child.children())) > 0:
                convert_module(child)

    convert_module(model)
    logger.info(f"Converted {layer_idx} attention layers to differential attention")

    return model
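The `copy_attention_weights` docstring above argues that zero-initializing Q2, K2, and the lambda parameters reproduces the original attention exactly. A small self-contained sketch of that argument, using toy tensors and the paper-style lambda formula (assumptions for illustration, not this repo's modules):

```python
import torch
import torch.nn.functional as F

# Toy check of the zero-init equivalence argument.
def diff_attn(q1, k1, q2, k2, v, lam):
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) / q1.shape[-1] ** 0.5, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) / q2.shape[-1] ** 0.5, dim=-1)
    return (a1 - lam * a2) @ v

q1, k1, v = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
q2, k2 = torch.zeros_like(q1), torch.zeros_like(k1)  # the zeroed Q2/K2 halves

# With zeroed lambda parameters, lam = exp(0) - exp(0) + 0 = 0, so the second
# softmax term drops out and the result equals standard attention on Q1/K1.
lam = torch.exp(torch.zeros(())) - torch.exp(torch.zeros(())) + 0.0
standard = F.softmax(q1 @ k1.transpose(-1, -2) / 8 ** 0.5, dim=-1) @ v
assert torch.allclose(diff_attn(q1, k1, q2, k2, v, lam), standard)
```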
@@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
        """Initialize configuration parameters."""
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
        self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
@@ -66,15 +68,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
        if config.split_heads:
            # Split heads mode - single projections
            self.head_dim = config.hidden_size // config.num_attention_heads
            # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
            # implementation, which asserts `self.base_num_heads` is even
            self.heads_per_component = self.base_num_heads // 2
            self.kv_heads_per_component = self.base_num_kv_heads // 2
            self.value_head_dim = 2 * self.head_dim
        else:
            # Double projection mode
            self.head_dim = config.hidden_size // config.num_attention_heads
            self.heads_per_component = self.base_num_heads
            self.kv_heads_per_component = self.base_num_kv_heads
            self.value_head_dim = self.head_dim

    def _init_projections(self):
@@ -90,14 +92,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
        )

        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
        self.q_proj = nn.Linear(
            self.hidden_size, q_out_dim, bias=self.config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size, k_out_dim, bias=self.config.attention_bias
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
            bias=False,
            bias=self.config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.base_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.config.attention_bias,
        )
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _init_differential_params(self):
        """Initialize differential attention parameters."""
@@ -145,13 +155,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
        q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
            1, 2
        )
        k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
        k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
            1, 2
        )
        k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
        k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
            1, 2
        )
        v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
        v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
            1, 2
        )
@@ -184,10 +194,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
        k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
        k1, k2 = k.unbind(dim=1)

        # Repeat KV heads
        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
        # Repeat KV heads to match number of query heads
        k1 = repeat_kv(k1, self.num_key_value_groups)
        k2 = repeat_kv(k2, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        return k1, k2, v
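To make the new `num_key_value_groups` bookkeeping concrete, a small sketch with illustrative head counts (numbers chosen for the example, not taken from the commit):

```python
import torch
from transformers.models.llama.modeling_llama import repeat_kv

# Illustrative GQA shapes: 8 query heads sharing 2 KV heads.
base_num_heads, base_num_kv_heads = 8, 2
num_key_value_groups = base_num_heads // base_num_kv_heads  # 4

bsz, q_len, head_dim = 1, 16, 64
k1 = torch.randn(bsz, base_num_kv_heads, q_len, head_dim)

# Each KV head is tiled so every query head has a matching key head.
k1_expanded = repeat_kv(k1, num_key_value_groups)
assert k1_expanded.shape == (bsz, base_num_heads, q_len, head_dim)
```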
@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)

        # Replace standard attention with differential attention in each layer
        for layer in self.layers:
        for idx, layer in enumerate(self.layers):
            attn_impl = config._attn_implementation or "eager"
            if attn_impl == "eager":
                layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx)
                layer.self_attn = LlamaDifferentialAttention(config, idx)
            elif attn_impl == "sdpa":
                layer.self_attn = LlamaDifferentialSdpaAttention(
                    config, layer.layer_idx
                )
                layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
            elif attn_impl == "flash_attention_2":
                layer.self_attn = LlamaDifferentialFlashAttention2(
                    config, layer.layer_idx
                )
                layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)

    @classmethod
    def from_llama(
@@ -78,7 +75,21 @@ class LlamaDifferentialModel(LlamaModel):
        if config is None:
            config = LlamaDifferentialConfig(**model.config.__dict__)

        # Validate head counts if using split heads mode
        if config.split_heads:
            if config.num_attention_heads % 2 != 0:
                raise ValueError(
                    f"Number of attention heads ({config.num_attention_heads}) must be even "
                    "when using split_heads=True"
                )
            if config.num_key_value_heads % 2 != 0:
                raise ValueError(
                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
                    "when using split_heads=True"
                )

        new_model = cls(config)

        # Copy all weights except attention
        new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
        new_model.norm.load_state_dict(model.norm.state_dict())
@@ -97,34 +108,28 @@ class LlamaDifferentialModel(LlamaModel):
            new_layer.self_attn.v_proj.load_state_dict(
                old_layer.self_attn.v_proj.state_dict()
            )
            print(old_layer.self_attn.o_proj.weight.shape)
            new_layer.self_attn.o_proj.load_state_dict(
                old_layer.self_attn.o_proj.state_dict()
            )

            if config.split_heads:
                new_layer.self_attn.q_proj.weight.data.copy_(
            # Get the original projection sizes
            old_q_size = old_layer.self_attn.q_proj.weight.size(0)
            old_k_size = old_layer.self_attn.k_proj.weight.size(0)

            if not config.split_heads:
                new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
                    old_layer.self_attn.q_proj.weight.data
                )
                new_layer.self_attn.k_proj.weight.data.copy_(
                    old_layer.self_attn.k_proj.weight.data
                )
            else:
                new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_(
                    old_layer.self_attn.q_proj.weight.data
                )
                new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_(
                new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
                    old_layer.self_attn.k_proj.weight.data
                )

            if config.zero_init:
                # Zero out components as needed
                with torch.no_grad():
                    new_layer.self_attn.q_proj.weight.data[
                        config.hidden_size :
                    ].zero_()
                    new_layer.self_attn.k_proj.weight.data[
                        config.hidden_size :
                    ].zero_()
                    new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
                    new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
                    new_layer.self_attn.lambda_q1.zero_()
                    new_layer.self_attn.lambda_k1.zero_()
                    new_layer.self_attn.lambda_q2.zero_()
@@ -149,7 +154,21 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
        if config is None:
            config = LlamaDifferentialConfig(**model.config.__dict__)

        # Validate head counts if using split heads mode
        if config.split_heads:
            if config.num_attention_heads % 2 != 0:
                raise ValueError(
                    f"Number of attention heads ({config.num_attention_heads}) must be even "
                    "when using split_heads=True"
                )
            if config.num_key_value_heads % 2 != 0:
                raise ValueError(
                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
                    "when using split_heads=True"
                )

        new_model = cls(config)
        new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
        new_model.lm_head.load_state_dict(model.lm_head.state_dict())

        return new_model
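A quick illustration of why the even-head validation above matters in split-heads mode, with assumed example counts (not values from the commit): each layer divides its existing heads between the two differential components.

```python
# Sketch of the split-heads bookkeeping, assuming Llama-style head counts.
num_attention_heads = 32
num_key_value_heads = 8

# Both counts must be even so the two components get balanced halves.
assert num_attention_heads % 2 == 0 and num_key_value_heads % 2 == 0

heads_per_component = num_attention_heads // 2      # 16 query heads per component
kv_heads_per_component = num_key_value_heads // 2   # 4 KV heads per component
print(heads_per_component, kv_heads_per_component)
```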
@@ -710,11 +710,30 @@ class ModelLoader:
        """
        sample packing uses custom FA2 patch
        """
        # if self.cfg.flash_attention:
        #     if not self.cfg.sample_packing and self.cfg.s2_attention:
        #         pass

        #     self.model_kwargs["attn_implementation"] = "flash_attention_2"
        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
        #         "flash_attention_2"
        #     )
        # elif self.cfg.sdp_attention:
        #     self.model_kwargs["attn_implementation"] = "sdpa"
        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
        #         "sdpa"
        #     )
        # elif self.cfg.eager_attention:
        #     self.model_kwargs["attn_implementation"] = "eager"
        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
        #         "eager"
        #     )

        if self.cfg.flash_attention:
            if not self.cfg.sample_packing and self.cfg.s2_attention:
                pass

        if self.cfg.differentiaion:
        if self.cfg.diff_attention:
            self.model_kwargs[
                "attn_implementation"
            ] = "differential_flash_attention_2"
@@ -15,135 +15,133 @@ from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs


@pytest.mark.usefixtures("base_config", "cli_runner")
class TestDiffTransformer:
    """Tests for convert-diff-transformer CLI command"""

    def test_cli_validation(self, cli_runner):
        # Test missing config file
        result = cli_runner.invoke(cli, ["convert-diff-transformer"])
        assert result.exit_code != 0
        assert "Error: Missing argument 'CONFIG'." in result.output

        # Test non-existent config file
        result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
        assert result.exit_code != 0
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        with patch(
            "axolotl.cli.integrations.convert_diff_transformer.do_cli"
        ) as mock_do_cli:
            result = cli_runner.invoke(
                cli, ["convert-diff-transformer", str(config_path)]
            )
            assert result.exit_code == 0

            mock_do_cli.assert_called_once()
            assert mock_do_cli.call_args.kwargs["config"] == str(config_path)

    def test_conversion_cli_basic(self, tmp_path: Path, base_config):
        output_dir = tmp_path / "converted"
        base_config["output_dir"] = str(output_dir)

        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        cfg = load_cfg(str(config_path))
        cli_args = ConvertDiffTransformerCliArgs()
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

        assert not debug_info
        assert (output_dir / "model.safetensors").exists()
        assert (output_dir / "config.json").exists()
        assert (output_dir / "axolotl_config.yml").exists()

    def test_conversion_cli_debug(self, tmp_path: Path, base_config):
        output_dir = tmp_path / "converted"
        base_config["output_dir"] = str(output_dir)

        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        cfg = load_cfg(str(config_path))
        cli_args = ConvertDiffTransformerCliArgs(debug=True)
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

        assert not debug_info["generations_match"]
        assert not debug_info["match_expected"]
        assert (output_dir / "model.safetensors").exists()
        assert (output_dir / "config.json").exists()
        assert (output_dir / "axolotl_config.yml").exists()

    def test_conversion_cli_reproduce(self, tmp_path: Path, base_config):
        output_dir = tmp_path / "converted"
        base_config["output_dir"] = str(output_dir)

        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        cfg = load_cfg(str(config_path))
        cli_args = ConvertDiffTransformerCliArgs(
            debug=True, zero_init=True, sublayer_norm=False
        )
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

        assert debug_info["generations_match"] is True
        assert (output_dir / "model.safetensors").exists()
        assert (output_dir / "config.json").exists()
        assert (output_dir / "axolotl_config.yml").exists()

    @pytest.mark.parametrize(
        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
    )
    def test_conversion_cli_repoduce_attentions(
        self, tmp_path: Path, base_config, attention: Optional[str]
    ):
        output_dir = tmp_path / "converted"
        base_config["output_dir"] = str(output_dir)
        base_config[attention] = True

        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        cfg = load_cfg(str(config_path))
        cli_args = ConvertDiffTransformerCliArgs(
            debug=True, zero_init=True, sublayer_norm=False
        )
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

        assert debug_info["generations_match"] is True
        assert (output_dir / "model.safetensors").exists()
        assert (output_dir / "config.json").exists()
        assert (output_dir / "axolotl_config.yml").exists()

    @pytest.mark.parametrize(
        "attention", ["eager_attention", "sdp_attention", "flash_attention"]
    )
    def test_conversion_cli_split_heads(
        self, tmp_path: Path, base_config, attention: str
    ):
        output_dir = tmp_path / "converted"
        base_config["output_dir"] = str(output_dir)
        base_config[attention] = True

        config_path = tmp_path / "config.yml"
        with open(config_path, "w", encoding="utf-8") as file:
            yaml.dump(base_config, file)

        cfg = load_cfg(str(config_path))
        cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

        assert debug_info["generations_match"] is False
        assert (output_dir / "model.safetensors").exists()
        assert (output_dir / "config.json").exists()
        assert (output_dir / "axolotl_config.yml").exists()