From 503c4e9ffa03b8a87c5e85a85d5d94dbb23ccd05 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 17 Dec 2024 04:43:08 +0000
Subject: [PATCH] fix model save / load logic

---
 .../integrations/convert_diff_transformer.py |  2 +-
 src/axolotl/common/cli.py                    |  4 +--
 src/axolotl/evaluate.py                      | 25 +++++++------------
 .../integrations/diff_transformer/convert.py |  3 ++-
 .../diff_transformer/multihead_diffattn.py   | 16 +++++++++---
 .../integrations/diff_transformer/patches.py |  5 +---
 src/axolotl/utils/models.py                  | 11 +++++---
 7 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index a8c7e5942..1cbf619c8 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -75,7 +75,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
     LOG.info("Converting to differential attention...")
     try:
         model = convert_to_diff_attention(model, cli_args.zero_init)
-        model.to(model.device)
+        model.to(cfg.device, dtype=cfg.torch_dtype)
     except Exception as exc:
         LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
         raise
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index bdab7c272..2b25b7f39 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -4,7 +4,7 @@ shared module for cli specific things
 import logging
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.logging_config import configure_logging
@@ -67,7 +67,7 @@ class ConvertDiffTransformerCliArgs:
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
-    cli_args: TrainerCliArgs,
+    cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
 ):
     LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
     tokenizer = load_tokenizer(cfg)
diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index acf15e3fc..bc1799960 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -9,13 +9,13 @@ from typing import Dict, Optional
 import torch
 from accelerate.logging import get_logger
 
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_processor, load_tokenizer
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.models import load_processor
+from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -63,7 +63,7 @@ def evaluate_dataset(
 
 
 def evaluate(
-    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
+    *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Dict[str, float]:
     """
     Evaluate a model on training and validation datasets
@@ -83,12 +83,11 @@ def evaluate(
     # Enable expandable segments for cuda allocation to improve VRAM usage
     set_pytorch_cuda_alloc_conf()
 
-    # Load tokenizer
-    LOG.debug(
-        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
-        main_process_only=True,
-    )
-    tokenizer = load_tokenizer(cfg)
+    # Load model
+    LOG.debug("loading model for evaluation...")
+
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     # Load processor for multimodal models if needed
     processor = None
@@ -100,12 +99,6 @@ def evaluate(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
-    # Load model
-    LOG.debug("loading model for evaluation...")
-    model, _ = load_model(
-        cfg, tokenizer, processor=processor, inference=cli_args.inference
-    )
-
     # Set up trainer
     trainer = setup_trainer(
         cfg,
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py
index 24bc07cf7..bd688fadb 100644
--- a/src/axolotl/integrations/diff_transformer/convert.py
+++ b/src/axolotl/integrations/diff_transformer/convert.py
@@ -59,7 +59,7 @@ def copy_attention_weights(
         nn.init.zeros_(new_attn.lambda_k1)
         nn.init.zeros_(new_attn.lambda_q2)
         nn.init.zeros_(new_attn.lambda_k2)
-        new_attn.lambda_init = 0.0
+        nn.init.zeros_(new_attn.lambda_init)
 
     logger.debug(
         "Copied positive attention weights from %s to %s",
@@ -105,6 +105,7 @@ def convert_to_diff_attention(
         )
 
         # Copy weights from old attention to new attention
+        new_attention.to(child.q_proj.weight.device)
         copy_attention_weights(child, new_attention, zero_init=zero_init)
 
         # Replace the layer
diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
index 7b3db19ab..473556445 100644
--- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
+++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
@@ -70,13 +70,12 @@ class LlamaDifferentialAttention(nn.Module):
 
         self.base_num_kv_heads = config.num_key_value_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
-        self.scaling = self.head_dim**-0.5
         self.layer_idx = layer_idx
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        dtype = getattr(config, "torch_dtype", torch.float32)
+        dtype = torch.float32
 
         # For Q1 and Q2
         self.q_proj = nn.Linear(
@@ -111,7 +110,10 @@ class LlamaDifferentialAttention(nn.Module):
         )
 
         # Initialize differential attention parameters
-        self.lambda_init = lambda_init_fn(self.layer_idx)
+        self.lambda_init = nn.Parameter(
+            torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
+            requires_grad=False,
+        )
         self.lambda_q1 = nn.Parameter(
             torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
         )
@@ -197,6 +199,14 @@ class LlamaDifferentialAttention(nn.Module):
             self.head_dim
         )
 
+        # Add this debug step right after computing attention weights in the forward pass
+        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
+        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
+
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
             attn_weights1 = attn_weights1 + causal_mask
diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py
index 14117bf63..37ad0a981 100644
--- a/src/axolotl/integrations/diff_transformer/patches.py
+++ b/src/axolotl/integrations/diff_transformer/patches.py
@@ -9,16 +9,13 @@ from .multihead_diffattn import (
 )
 
 
-def patch_transformers():
+def patch_llama_attention_classes():
     """Patch transformers to support differential attention"""
 
     # Add our attention class to the registry
     LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
     LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
 
-    # Store original method for use in our patch
-    # original_autoset = PreTrainedModel._autoset_attn_implementation
-
     @classmethod
     def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
         config._attn_implementation_autoset = True  # pylint: disable=protected-access
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c8e08468f..3b0dcbc2b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -444,6 +444,13 @@ class ModelLoader:
 
             patch_mistral_cross_entropy()
 
+        if self.cfg.diff_attention:
+            from axolotl.integrations.diff_transformer.patches import (
+                patch_llama_attention_classes,
+            )
+
+            patch_llama_attention_classes()
+
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1058,10 +1065,6 @@ class ModelLoader:
         integrate_rope_embeddings()
 
     def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-        from axolotl.integrations.diff_transformer.patches import patch_transformers
-
-        patch_transformers()
-
         self.apply_patches()
         self.set_auto_model_loader()
         self.set_device_map_config()
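
Note on the save/load fix: the patch turns lambda_init from a plain Python float on
LlamaDifferentialAttention into a non-trainable nn.Parameter, so the value travels with the
module's state_dict and is restored by load_state_dict when a converted model is saved and
reloaded. A minimal sketch of that behavior, assuming only standard PyTorch (ToyDiffAttn is a
hypothetical stand-in, not an axolotl class):

    import torch
    import torch.nn as nn

    class ToyDiffAttn(nn.Module):
        """Toy stand-in showing why lambda_init must be a Parameter to survive save/load."""

        def __init__(self, lambda_init: float):
            super().__init__()
            # Plain attribute: not tracked by state_dict, so it is lost on save/load.
            self.lambda_init_plain = lambda_init
            # Non-trainable Parameter: serialized with the checkpoint and restored on load.
            self.lambda_init = nn.Parameter(
                torch.full((), lambda_init), requires_grad=False
            )

    saved = ToyDiffAttn(lambda_init=0.8).state_dict()
    assert "lambda_init" in saved            # the Parameter is checkpointed
    assert "lambda_init_plain" not in saved  # the plain float is not

    restored = ToyDiffAttn(lambda_init=0.123)
    restored.load_state_dict(saved)
    print(float(restored.lambda_init))       # 0.8, recovered from the checkpoint

The same reasoning applies to copy_attention_weights above: once lambda_init is a tensor-backed
Parameter, zero-initialization goes through nn.init.zeros_ rather than a float assignment.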