From eb6611d55ff59d20c15253b30ea3e77be3e0b0db Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 24 Dec 2024 05:30:46 +0000
Subject: [PATCH] progress on modeling code

---
 .../integrations/convert_diff_transformer.py  |  46 ++--
 .../integrations/diff_transformer/README.md   |   3 +
 .../integrations/diff_transformer/__init__.py |  13 +-
 .../integrations/diff_transformer/args.py     |   3 +
 .../integrations/diff_transformer/convert.py  | 135 -----------
 .../diff_transformer/diff_attn.py             |  36 +--
 .../diff_transformer/modeling_diff_attn.py    |  66 ++++--
 src/axolotl/utils/models.py                   |   2 +-
 .../test_convert_diff_transformer.py          | 208 +++++++++---------
 9 files changed, 196 insertions(+), 316 deletions(-)
 delete mode 100644 src/axolotl/integrations/diff_transformer/convert.py

diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index db4b0df4d..28cc87bbd 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -26,34 +26,27 @@ LOG = logging.getLogger(__name__)
 
 def test_inference(model, tokenizer, prompt="The quick brown fox"):
     """Run test inference and return generation time"""
-    try:
-        inputs = tokenizer(prompt, return_tensors="pt")
-        inputs = {
-            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
-        }
+    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
 
-        start = time()
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=20,
-                num_beams=1,
-                do_sample=False,
-                pad_token_id=tokenizer.pad_token_id,
-                use_cache=False,
-            )
-        elapsed = time() - start
+    start = time()
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=20,
+            num_beams=1,
+            do_sample=False,
+            pad_token_id=tokenizer.pad_token_id,
+            use_cache=False,
+        )
+    elapsed = time() - start
 
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        LOG.info("Prompt: %s", prompt)
-        LOG.info("Generated: %s", generated_text)
-        LOG.info("Generation time: %.2fs", elapsed)
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    LOG.info("Prompt: %s", prompt)
+    LOG.info("Generated: %s", generated_text)
+    LOG.info("Generation time: %.2fs", elapsed)
 
-        return elapsed, generated_text
-
-    except Exception as exc:
-        LOG.error("Inference failed: %s", str(exc))
-        raise
+    return elapsed, generated_text
 
 
 def convert_diff_transformer(cfg, cli_args, config_path):
@@ -89,7 +82,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         + Fore.RESET
     )
     try:
-        LlamaDifferentialForCausalLM.from_llama(
+        model = LlamaDifferentialForCausalLM.from_llama(
             model,
             LlamaDifferentialConfig(
                 **model.config.__dict__,
@@ -98,6 +91,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
                 split_heads=cli_args.split_heads,
             ),
         )
+        model.to(cfg.device, dtype=cfg.torch_dtype)
     except Exception as exc:
         LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
         raise
diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md
index 14473f753..a683fdf1d 100644
--- a/src/axolotl/integrations/diff_transformer/README.md
+++ b/src/axolotl/integrations/diff_transformer/README.md
@@ -7,4 +7,7 @@ plugins:
   - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
 
 diff_attention: true
+diff_attn_zero_init: false
+diff_attn_sublayer_norm: true
+diff_attn_split_heads: false
 ```
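Note: the conversion entry point above returns a (model, debug_info) pair and is driven the same way by the e2e tests at the end of this patch. A minimal sketch of calling it directly; the config path and option values are assumptions, not part of this patch:

    # Sketch only: the YAML path is hypothetical; option values mirror the
    # "reproduce" e2e test, where the converted model must match the original.
    from axolotl.cli import load_cfg  # import path assumed
    from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
    from axolotl.common.cli import ConvertDiffTransformerCliArgs

    config_path = "examples/diff_transformer.yml"  # hypothetical
    cfg = load_cfg(config_path)
    cli_args = ConvertDiffTransformerCliArgs(debug=True, zero_init=True, sublayer_norm=False)
    model, debug_info = convert_diff_transformer(cfg, cli_args, config_path)
    assert debug_info["generations_match"]
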
diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py
index 70459e026..461ede4fd 100644
--- a/src/axolotl/integrations/diff_transformer/__init__.py
+++ b/src/axolotl/integrations/diff_transformer/__init__.py
@@ -8,18 +8,7 @@ LOG = logging.getLogger(__name__)
 
 
 class DifferentialTransformerPlugin(BasePlugin):
-    """
-    Plugin for differential transformer integration with Axolotl.
-    """
+    """Plugin for differential transformer integration with Axolotl."""
 
     def get_input_args(self):
         return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
-
-    def pre_model_load(self, cfg):
-        """Apply differential attention patch before model loading if enabled."""
-        if cfg.diff_attention:
-            from axolotl.monkeypatch.attention.differential import (
-                patch_llama_attention_classes,
-            )
-
-            patch_llama_attention_classes()
diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py
index 47c1fe110..332c0b4aa 100644
--- a/src/axolotl/integrations/diff_transformer/args.py
+++ b/src/axolotl/integrations/diff_transformer/args.py
@@ -12,3 +12,6 @@ class DifferentialTransformerArgs(BaseModel):
     """Input args for differential transformer."""
 
     diff_attention: Optional[bool] = None
+    diff_attn_zero_init: Optional[bool] = None
+    diff_attn_sublayer_norm: Optional[bool] = None
+    diff_attn_split_heads: Optional[bool] = None
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py
deleted file mode 100644
index 298a0232e..000000000
--- a/src/axolotl/integrations/diff_transformer/convert.py
+++ /dev/null
@@ -1,135 +0,0 @@
-"""Differential attention conversion logic for a huggingface pre-trained model."""
-import logging
-from typing import Union
-
-import torch
-from torch import nn
-from transformers import PreTrainedModel
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaFlashAttention2,
-    LlamaSdpaAttention,
-)
-
-from .diff_attn import (
-    LlamaDifferentialAttention,
-    LlamaDifferentialFlashAttention2,
-    LlamaDifferentialSdpaAttention,
-)
-
-logger = logging.getLogger(__name__)
-
-ATTENTION_MAPPING = {
-    LlamaAttention: LlamaDifferentialAttention,
-    LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
-    LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
-}
-
-
-def copy_attention_weights(
-    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
-    new_attn: Union[
-        LlamaDifferentialAttention,
-        LlamaDifferentialSdpaAttention,
-        LlamaDifferentialFlashAttention2,
-    ],
-    zero_init: bool = False,
-) -> None:
-    """
-    Copy weights from old attention layer to new differential attention layer.
-    Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
-    to original attention mechanism.
- """ - # For Q projection (Q1 and Q2) - new_q = torch.empty_like(new_attn.q_proj.weight.data) - new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1 - if zero_init: - new_q[new_attn.hidden_size :] = 0 - else: - nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1) - new_attn.q_proj.weight.data.copy_(new_q) - - # For K projection (K1 and K2) - old_kv_size = old_attn.k_proj.weight.data.size(0) - new_k = torch.empty_like(new_attn.k_proj.weight.data) - new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1 - if zero_init: - new_k[old_kv_size:] = 0 - else: - nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1) - new_attn.k_proj.weight.data.copy_(new_k) - - # For V projection (single V) - new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data) - - # Output projection remains the same - new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data) - - # Zero out lambda parameters for exact equivalence - if zero_init: - nn.init.zeros_(new_attn.lambda_q1) - nn.init.zeros_(new_attn.lambda_k1) - nn.init.zeros_(new_attn.lambda_q2) - nn.init.zeros_(new_attn.lambda_k2) - nn.init.zeros_(new_attn.lambda_init) - - logger.debug( - "Copied positive attention weights from %s to %s", - type(old_attn).__name__, - type(new_attn).__name__, - ) - - -def convert_to_diff_attn( - model: PreTrainedModel, - zero_init: bool = False, - sublayer_norm: bool = True, - split_heads: bool = True, -) -> PreTrainedModel: - """Convert a pre-trained model's attention layers to differential attention""" - layer_idx = 0 - - # Set sublayer norm as config on the model. - model.config.sublayer_norm = sublayer_norm - model.config.split_heads = split_heads - - def convert_module(module): - nonlocal layer_idx - - # Iterate through module children, convert any attn layers to diff attn - for name, child in module.named_children(): - child_class_name = type(child).__name__ - - if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]: - # Find matching attention class by name - for orig_class, diff_class in ATTENTION_MAPPING.items(): - if orig_class.__name__ == child_class_name: - attention_class = diff_class - break - - layer_type = type(child).__name__ - logger.info( - f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}" - ) - - # Create new diff attn layer - new_attention = attention_class( - config=module.config if hasattr(module, "config") else model.config, - layer_idx=layer_idx, - ) - - # Copy weights from old attention to new attention - new_attention.to(child.q_proj.weight.device) - if not split_heads: - copy_attention_weights(child, new_attention, zero_init=zero_init) - - # Replace the layer - setattr(module, name, new_attention) - layer_idx += 1 - elif len(list(child.children())) > 0: - convert_module(child) - - convert_module(model) - logger.info(f"Converted {layer_idx} attention layers to differential attention") - - return model diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index cccb0adeb..5ae503464 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module): """Initialize configuration parameters.""" self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // config.num_attention_heads self.base_num_heads = config.num_attention_heads self.base_num_kv_heads = 
diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py
index cccb0adeb..5ae503464 100644
--- a/src/axolotl/integrations/diff_transformer/diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -56,8 +56,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
         """Initialize configuration parameters."""
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
+        self.head_dim = config.hidden_size // config.num_attention_heads
         self.base_num_heads = config.num_attention_heads
         self.base_num_kv_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
         self.layer_idx = layer_idx
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
@@ -66,15 +68,15 @@ class LlamaDifferentialAttentionBase(nn.Module):
 
         if config.split_heads:
             # Split heads mode - single projections
-            self.head_dim = config.hidden_size // config.num_attention_heads
             # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
             # implementation, which asserts `self.base_num_heads` is even
             self.heads_per_component = self.base_num_heads // 2
+            self.kv_heads_per_component = self.base_num_kv_heads // 2
             self.value_head_dim = 2 * self.head_dim
         else:
             # Double projection mode
-            self.head_dim = config.hidden_size // config.num_attention_heads
             self.heads_per_component = self.base_num_heads
+            self.kv_heads_per_component = self.base_num_kv_heads
             self.value_head_dim = self.head_dim
 
     def _init_projections(self):
@@ -90,14 +92,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
         )
 
-        self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
+        self.q_proj = nn.Linear(
+            self.hidden_size, q_out_dim, bias=self.config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, k_out_dim, bias=self.config.attention_bias
+        )
         self.v_proj = nn.Linear(
             self.hidden_size,
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
-            bias=False,
+            bias=self.config.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            self.base_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=self.config.attention_bias,
         )
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
     def _init_differential_params(self):
         """Initialize differential attention parameters."""
@@ -145,13 +155,13 @@ class LlamaDifferentialAttentionBase(nn.Module):
         q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
+        k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
+        k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
             1, 2
         )
-        v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
+        v = v.view(bsz, q_len, self.kv_heads_per_component, self.value_head_dim).transpose(
             1, 2
         )
 
@@ -184,10 +194,10 @@ class LlamaDifferentialAttentionBase(nn.Module):
             k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
             k1, k2 = k.unbind(dim=1)
 
-        # Repeat KV heads
-        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
-        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
-        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
+        # Repeat KV heads to match number of query heads
+        k1 = repeat_kv(k1, self.num_key_value_groups)
+        k2 = repeat_kv(k2, self.num_key_value_groups)
+        v = repeat_kv(v, self.num_key_value_groups)
 
         return k1, k2, v
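The head bookkeeping above is easiest to sanity-check numerically. A sketch under an assumed GQA shape (32 query heads, 8 KV heads, head_dim 128 — not tied to any particular checkpoint); the final assert is the constraint that makes kv_heads_per_component, rather than the full KV head count, the right reshape for v in split-heads mode:

    base_num_heads, base_num_kv_heads, head_dim = 32, 8, 128
    num_key_value_groups = base_num_heads // base_num_kv_heads  # 4, used by repeat_kv

    for split_heads in (False, True):
        if split_heads:
            heads_per_component = base_num_heads // 2        # 16 query heads per component
            kv_heads_per_component = base_num_kv_heads // 2  # 4 KV heads per component
            value_head_dim = 2 * head_dim                    # V spans both components
        else:
            heads_per_component = base_num_heads
            kv_heads_per_component = base_num_kv_heads
            value_head_dim = head_dim
        # After repeat_kv, each component's KV head count matches its query head count:
        assert kv_heads_per_component * num_key_value_groups == heads_per_component
        # The v reshape must consume exactly the v_proj output per token:
        assert kv_heads_per_component * value_head_dim == base_num_kv_heads * head_dim
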
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index a3d31382d..b84dfcd16 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -56,19 +56,16 @@ class LlamaDifferentialModel(LlamaModel):
 
     def __init__(self, config):
         super().__init__(config)
+        # Replace standard attention with differential attention in each layer
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
             attn_impl = config._attn_implementation or "eager"
             if attn_impl == "eager":
-                layer.self_attn = LlamaDifferentialAttention(config, layer.layer_idx)
+                layer.self_attn = LlamaDifferentialAttention(config, idx)
             elif attn_impl == "sdpa":
-                layer.self_attn = LlamaDifferentialSdpaAttention(
-                    config, layer.layer_idx
-                )
+                layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
             elif attn_impl == "flash_attention_2":
-                layer.self_attn = LlamaDifferentialFlashAttention2(
-                    config, layer.layer_idx
-                )
+                layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
 
     @classmethod
     def from_llama(
@@ -78,7 +75,21 @@ class LlamaDifferentialModel(LlamaModel):
         if config is None:
             config = LlamaDifferentialConfig(**model.config.__dict__)
 
+        # Validate head counts if using split heads mode
+        if config.split_heads:
+            if config.num_attention_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of attention heads ({config.num_attention_heads}) must be even "
+                    "when using split_heads=True"
+                )
+            if config.num_key_value_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
+                    "when using split_heads=True"
+                )
+
         new_model = cls(config)
+        # Copy all weights except attention
         new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
         new_model.norm.load_state_dict(model.norm.state_dict())
@@ -97,34 +108,27 @@ class LlamaDifferentialModel(LlamaModel):
             new_layer.self_attn.v_proj.load_state_dict(
                 old_layer.self_attn.v_proj.state_dict()
             )
             new_layer.self_attn.o_proj.load_state_dict(
                 old_layer.self_attn.o_proj.state_dict()
             )
 
-            if config.split_heads:
-                new_layer.self_attn.q_proj.weight.data.copy_(
+            # Get the original projection sizes
+            old_q_size = old_layer.self_attn.q_proj.weight.size(0)
+            old_k_size = old_layer.self_attn.k_proj.weight.size(0)
+
+            if not config.split_heads:
+                new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
                     old_layer.self_attn.q_proj.weight.data
                 )
-                new_layer.self_attn.k_proj.weight.data.copy_(
-                    old_layer.self_attn.k_proj.weight.data
-                )
-            else:
-                new_layer.self_attn.q_proj.weight.data[: config.hidden_size].copy_(
-                    old_layer.self_attn.q_proj.weight.data
-                )
-                new_layer.self_attn.k_proj.weight.data[: config.hidden_size].copy_(
+                new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
                     old_layer.self_attn.k_proj.weight.data
                 )
 
                 if config.zero_init:
                     # Zero out components as needed
                     with torch.no_grad():
-                        new_layer.self_attn.q_proj.weight.data[
-                            config.hidden_size :
-                        ].zero_()
-                        new_layer.self_attn.k_proj.weight.data[
-                            config.hidden_size :
-                        ].zero_()
+                        new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
+                        new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
                         new_layer.self_attn.lambda_q1.zero_()
                         new_layer.self_attn.lambda_k1.zero_()
                         new_layer.self_attn.lambda_q2.zero_()
@@ -149,7 +154,21 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
         if config is None:
             config = LlamaDifferentialConfig(**model.config.__dict__)
 
+        # Validate head counts if using split heads mode
+        if config.split_heads:
+            if config.num_attention_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of attention heads ({config.num_attention_heads}) must be even "
+                    "when using split_heads=True"
+                )
+            if config.num_key_value_heads % 2 != 0:
+                raise ValueError(
+                    f"Number of key/value heads ({config.num_key_value_heads}) must be even "
+                    "when using split_heads=True"
+                )
+
         new_model = cls(config)
         new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
         new_model.lm_head.load_state_dict(model.lm_head.state_dict())
+
         return new_model
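A sketch of the invariant from_llama() aims to preserve: with zero_init=True and sublayer_norm=False, the converted model should reproduce the donor model's logits. Checkpoint name, import paths, and tolerance are assumptions:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # import path assumed
    from axolotl.integrations.diff_transformer.modeling_diff_attn import (
        LlamaDifferentialConfig,
        LlamaDifferentialForCausalLM,
    )

    name = "HuggingFaceTB/SmolLM2-135M"  # hypothetical small Llama-style checkpoint
    base = AutoModelForCausalLM.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)

    config = LlamaDifferentialConfig(
        **base.config.__dict__, zero_init=True, sublayer_norm=False
    )
    converted = LlamaDifferentialForCausalLM.from_llama(base, config)

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    with torch.no_grad():
        assert torch.allclose(
            base(**inputs).logits, converted(**inputs).logits, atol=1e-5
        )
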
even " + "when using split_heads=True" + ) + new_model = cls(config) new_model.model = LlamaDifferentialModel.from_llama(model.model, config) new_model.lm_head.load_state_dict(model.lm_head.state_dict()) + return new_model diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 37cbc0871..2c4d2513d 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -710,11 +710,30 @@ class ModelLoader: """ sample packing uses custom FA2 patch """ + # if self.cfg.flash_attention: + # if not self.cfg.sample_packing and self.cfg.s2_attention: + # pass + + # self.model_kwargs["attn_implementation"] = "flash_attention_2" + # self.model_config._attn_implementation = ( # pylint: disable=protected-access + # "flash_attention_2" + # ) + # elif self.cfg.sdp_attention: + # self.model_kwargs["attn_implementation"] = "sdpa" + # self.model_config._attn_implementation = ( # pylint: disable=protected-access + # "sdpa" + # ) + # elif self.cfg.eager_attention: + # self.model_kwargs["attn_implementation"] = "eager" + # self.model_config._attn_implementation = ( # pylint: disable=protected-access + # "eager" + # ) + if self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - if self.cfg.differentiaion: + if self.cfg.diff_attention: self.model_kwargs[ "attn_implementation" ] = "differential_flash_attention_2" diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index 02939ee1c..e616a8ef1 100644 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -15,135 +15,133 @@ from axolotl.cli.main import cli from axolotl.common.cli import ConvertDiffTransformerCliArgs -@pytest.mark.usefixtures("base_config", "cli_runner") -class TestDiffTransformer: - """Tests for convert-diff-transformer CLI command""" +def test_cli_validation(cli_runner): + # Test missing config file + result = cli_runner.invoke(cli, ["convert-diff-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." in result.output - def test_cli_validation(self, cli_runner): - # Test missing config file - result = cli_runner.invoke(cli, ["convert-diff-transformer"]) - assert result.exit_code != 0 - assert "Error: Missing argument 'CONFIG'." 
diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
index 02939ee1c..e616a8ef1 100644
--- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
+++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
@@ -15,135 +15,133 @@ from axolotl.cli.main import cli
 from axolotl.common.cli import ConvertDiffTransformerCliArgs
 
-@pytest.mark.usefixtures("base_config", "cli_runner")
-class TestDiffTransformer:
-    """Tests for convert-diff-transformer CLI command"""
+def test_cli_validation(cli_runner):
+    # Test missing config file
+    result = cli_runner.invoke(cli, ["convert-diff-transformer"])
+    assert result.exit_code != 0
+    assert "Error: Missing argument 'CONFIG'." in result.output
 
-    def test_cli_validation(self, cli_runner):
-        # Test missing config file
-        result = cli_runner.invoke(cli, ["convert-diff-transformer"])
-        assert result.exit_code != 0
-        assert "Error: Missing argument 'CONFIG'." in result.output
+    # Test non-existent config file
+    result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
+    assert result.exit_code != 0
+    assert "Error: Invalid value for 'CONFIG'" in result.output
 
-        # Test non-existent config file
-        result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
-        assert result.exit_code != 0
-        assert "Error: Invalid value for 'CONFIG'" in result.output
 
-    def test_basic_execution(self, cli_runner, tmp_path: Path, base_config):
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
+def test_basic_execution(cli_runner, tmp_path: Path, base_config):
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
 
-        with patch(
-            "axolotl.cli.integrations.convert_diff_transformer.do_cli"
-        ) as mock_do_cli:
-            result = cli_runner.invoke(
-                cli, ["convert-diff-transformer", str(config_path)]
-            )
-            assert result.exit_code == 0
+    with patch(
+        "axolotl.cli.integrations.convert_diff_transformer.do_cli"
+    ) as mock_do_cli:
+        result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
+        assert result.exit_code == 0
 
-            mock_do_cli.assert_called_once()
-            assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
 
-    def test_conversion_cli_basic(self, tmp_path: Path, base_config):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
+
+def test_conversion_cli_basic(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
 
-        config_path = tmp_path / "config.yml"
-        with open(config_path, "w", encoding="utf-8") as file:
-            yaml.dump(base_config, file)
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
 
-        cfg = load_cfg(str(config_path))
-        cli_args = ConvertDiffTransformerCliArgs()
-        _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs()
+    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
 
-        assert not debug_info
-        assert (output_dir / "model.safetensors").exists()
-        assert (output_dir / "config.json").exists()
-        assert (output_dir / "axolotl_config.yml").exists()
+    assert not debug_info
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
 
-    def test_conversion_cli_debug(self, tmp_path: Path, base_config):
-        output_dir = tmp_path / "converted"
-        base_config["output_dir"] = str(output_dir)
+
+def test_conversion_cli_debug(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
 
-        config_path = tmp_path / "config.yml"
"config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) - def test_conversion_cli_reproduce(self, tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) + assert not debug_info["generations_match"] + assert not debug_info["match_expected"] + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() +def test_conversion_cli_reproduce(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) - @pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False ) - def test_conversion_cli_repoduce_attentions( - self, tmp_path: Path, base_config, attention: Optional[str] - ): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() +@pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] +) +def test_conversion_cli_repoduce_attentions( + tmp_path: Path, base_config, attention: Optional[str] +): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True - @pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False ) - def test_conversion_cli_split_heads( - self, tmp_path: Path, base_config, attention: str - ): - output_dir 
= tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - assert debug_info["generations_match"] is False - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() +@pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] +) +def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is False + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists()