diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index 116a60480..a8c7e5942 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -13,7 +13,7 @@ from dotenv import load_dotenv
 from transformers import HfArgumentParser
 
 from axolotl.cli import load_cfg, print_axolotl_text_art
-from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
 
 LOG = logging.getLogger("axolotl.cli.convert_attention")
@@ -67,21 +67,23 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         )
 
         # Test original model
-        LOG.info("Testing original model...")
-        orig_time, orig_text = test_inference(model, tokenizer)
+        if cli_args.debug:
+            LOG.info("Testing original model...")
+            orig_time, orig_text = test_inference(model, tokenizer)
 
         # Convert attention
         LOG.info("Converting to differential attention...")
         try:
-            model = convert_to_diff_attention(model)
+            model = convert_to_diff_attention(model, cli_args.zero_init)
             model.to(model.device)
         except Exception as exc:
             LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
             raise
 
         # Test converted model
-        LOG.info("Testing converted model...")
-        conv_time, conv_text = test_inference(model, tokenizer)
+        if cli_args.debug:
+            LOG.info("Testing converted model...")
+            conv_time, conv_text = test_inference(model, tokenizer)
 
         # Save if requested
         if cfg.output_dir:
@@ -106,30 +108,65 @@ def convert_diff_transformer(cfg, cli_args, config_path):
             LOG.info("Not saving converted model to disk")
             LOG.info("Pass --output-dir path/to/save to save model")
 
-        LOG.info(
-            Fore.GREEN
-            + "Conversion successful!\n"
-            + f"Original generation time: {orig_time:.2f}s\n"
-            + f"Converted generation time: {conv_time:.2f}s"
-            + Fore.RESET
-        )
-
-        if orig_text == conv_text:
+        if cli_args.debug:
             LOG.info(
                 Fore.GREEN
-                + "Generations match!\n"
-                + f"Model generation: {orig_text}\n"
-                + Fore.RESET
-            )
-        else:
-            LOG.info(
-                Fore.RED
-                + "Generations do not match.\n"
-                + f"Original generation: {orig_text}\n"
-                + f"Converted generation: {conv_text}\n"
+                + "Conversion successful!\n"
+                + f"Original generation time: {orig_time:.2f}s\n"
+                + f"Converted generation time: {conv_time:.2f}s"
                 + Fore.RESET
             )
 
+            if orig_text == conv_text:
+                LOG.info(
+                    Fore.GREEN
+                    + "Generations match!\n"
+                    + "Model generation:\n"
+                    + "*" * 50
+                    + "\n"
+                    + f"{orig_text}\n"
+                    + "*" * 50
+                    + "\n"
+                    + Fore.RESET
+                )
+            else:
+                if cli_args.zero_init:
+                    LOG.info(
+                        Fore.RED
+                        + "Generations do not match.\n"
+                        + "Original generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{orig_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "Converted generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{conv_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + Fore.RESET
+                    )
+                else:
+                    LOG.info(
+                        Fore.YELLOW
+                        + "Generations do not match.\n"
+                        + "Original generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{orig_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "Converted generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{conv_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "However, this is expected since --zero-init was not passed."
+                        + Fore.RESET
+                    )
     except Exception as exc:
         LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
         raise
@@ -139,7 +176,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     print_axolotl_text_art()
     cfg = load_cfg(config, **kwargs)
 
-    parser = HfArgumentParser(TrainerCliArgs)
+    parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
     cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
     convert_diff_transformer(cfg, cli_args, config)
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 360af4810..c37aa5484 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -12,7 +12,7 @@ from axolotl.cli.utils import (
     build_command,
     fetch_from_github,
 )
-from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
@@ -242,7 +242,7 @@ def merge_lora(
 
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
-@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 def convert_diff_transformer(config: str, **kwargs):
     """Convert model attention layers to differential attention layers."""
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 02ad9201b..bdab7c272 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -18,7 +18,7 @@ LOG = logging.getLogger("axolotl.common.cli")
 @dataclass
 class PreprocessCliArgs:
     """
-    dataclass representing arguments for preprocessing only
+    dataclass with arguments for preprocessing only
     """
 
     debug: bool = field(default=False)
@@ -31,7 +31,7 @@ class TrainerCliArgs:
     """
-    dataclass representing the various non-training arguments
+    dataclass with various non-training arguments
     """
 
     debug: bool = field(default=False)
@@ -46,7 +46,7 @@ class EvaluateCliArgs:
     """
-    dataclass representing the various evaluation arguments
+    dataclass with various evaluation arguments
     """
 
     debug: bool = field(default=False)
@@ -54,6 +54,16 @@ class EvaluateCliArgs:
     debug_num_examples: int = field(default=0)
 
 
+@dataclass
+class ConvertDiffTransformerCliArgs:
+    """
+    dataclass with arguments for convert-diff-transformer CLI
+    """
+
+    debug: bool = field(default=False)
+    zero_init: bool = field(default=False)
+
+
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py
index 584c19d5f..24bc07cf7 100644
--- a/src/axolotl/integrations/diff_transformer/convert.py
+++ b/src/axolotl/integrations/diff_transformer/convert.py
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 def copy_attention_weights(
     old_attn: Union[LlamaAttention, LlamaSdpaAttention],
     new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
-    zero_init: bool = True,
+    zero_init: bool = False,
 ) -> None:
     """
     Copy weights from old attention layer to new differential attention layer.
@@ -68,7 +68,9 @@ def copy_attention_weights(
     )
 
 
-def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
+def convert_to_diff_attention(
+    model: PreTrainedModel, zero_init: bool
+) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
     attention_patterns = (
         LlamaAttention,
@@ -78,9 +80,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
     )
     layer_idx = 0
 
-    # Get model dtype from existing weights
-    model_dtype = next(model.parameters()).dtype
-
     def convert_module(module):
         nonlocal layer_idx
 
@@ -103,11 +102,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
             new_attention = attention_class(
                 config=module.config if hasattr(module, "config") else model.config,
                 layer_idx=layer_idx,
-                dtype=model_dtype,
             )
 
             # Copy weights from old attention to new attention
-            copy_attention_weights(child, new_attention)
+            copy_attention_weights(child, new_attention, zero_init=zero_init)
 
             # Replace the layer
             setattr(module, name, new_attention)
diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
index 6d3bc7589..ace9c58de 100644
--- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
+++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
@@ -60,11 +60,10 @@ class LlamaDifferentialAttention(nn.Module):
         self,
         config: Any,
         layer_idx: int,
-        dtype: torch.dtype,
     ):
         super().__init__()
 
-        # Base model dimensions
+        # Base model config
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.base_num_heads = config.num_attention_heads
@@ -77,6 +76,8 @@ class LlamaDifferentialAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
+        dtype = getattr(config, "torch_dtype", torch.float32)
+
         # For Q1 and Q2
         self.q_proj = nn.Linear(
             self.hidden_size,
diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py
index 7ff35633c..14117bf63 100644
--- a/src/axolotl/integrations/diff_transformer/patches.py
+++ b/src/axolotl/integrations/diff_transformer/patches.py
@@ -1,4 +1,5 @@
 """Patches related to differential transformers implementation."""
+
 from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7d8599c3c..c8e08468f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -710,13 +710,6 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
-        print(
-            self.cfg.flash_attention,
-            self.cfg.sdp_attention,
-            self.cfg.eager_attention,
-            self.cfg.diff_attention,
-        )
-
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
@@ -761,9 +754,6 @@ class ModelLoader:
                 "differential_eager"
             )
 
-        if "attn_implementation" in self.model_kwargs:
-            print(self.model_kwargs["attn_implementation"])
-
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True
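
For reference, a minimal sketch of how the conversion entry point added above might be exercised directly from Python, rather than through the new "axolotl convert-diff-transformer path/to/config.yml --zero-init --debug" command. The checkpoint id below is a placeholder, not part of this patch; any Llama-family model should apply, since only LlamaAttention/LlamaSdpaAttention layers are converted:

    # Minimal sketch, assuming a Llama-family checkpoint (placeholder model id).
    from transformers import AutoModelForCausalLM

    from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

    # zero_init=True mirrors the --zero-init flag; per the debug logging above,
    # only in this mode is the converted model expected to reproduce the
    # original model's generations.
    model = convert_to_diff_attention(model, zero_init=True)

    # As in the CLI above, move the freshly created attention layers onto the
    # same device as the rest of the model.
    model.to(model.device)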