fix model save / load logic

Dan Saunders
2024-12-17 04:43:08 +00:00
parent 845dbede53
commit 503c4e9ffa
7 changed files with 35 additions and 31 deletions

View File

@@ -75,7 +75,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
LOG.info("Converting to differential attention...")
try:
model = convert_to_diff_attention(model, cli_args.zero_init)
-model.to(model.device)
+model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise

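The cast above is needed because convert_to_diff_attention swaps in freshly constructed attention modules, which are built in float32 on the default device (see the dtype = torch.float32 change further down) regardless of how the rest of the model was loaded; one model-wide move to cfg.device / cfg.torch_dtype makes the model homogeneous again before anything is trained or saved. A minimal sketch of the failure mode with toy modules, not the axolotl conversion itself:

    import torch
    import torch.nn as nn

    # Stand-in for module surgery: the replacement layer comes up in float32
    # even though the surrounding model was loaded in bfloat16.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to(torch.bfloat16)
    model[1] = nn.Linear(8, 8)
    print({p.dtype for p in model.parameters()})  # {torch.bfloat16, torch.float32}

    # One model-wide cast, analogous to model.to(cfg.device, dtype=cfg.torch_dtype),
    # restores a single dtype (and device) for every parameter.
    model.to("cpu", dtype=torch.bfloat16)
    print({p.dtype for p in model.parameters()})  # {torch.bfloat16}
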
View File

@@ -4,7 +4,7 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
@@ -67,7 +67,7 @@ class ConvertDiffTransformerCliArgs:
def load_model_and_tokenizer(
*,
cfg: DictDefault,
-cli_args: TrainerCliArgs,
+cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -9,13 +9,13 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
-from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_processor, load_tokenizer
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.models import load_processor
+from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
@@ -63,7 +63,7 @@ def evaluate_dataset(
def evaluate(
-*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
+*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
@@ -83,12 +83,11 @@ def evaluate(
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
-# Load tokenizer
-LOG.debug(
-f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
-main_process_only=True,
-)
-tokenizer = load_tokenizer(cfg)
+# Load model
+LOG.debug("loading model for evaluation...")
+model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Load processor for multimodal models if needed
processor = None
@@ -100,12 +99,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
-# Load model
-LOG.debug("loading model for evaluation...")
-model, _ = load_model(
-cfg, tokenizer, processor=processor, inference=cli_args.inference
-)
# Set up trainer
trainer = setup_trainer(
cfg,

View File

@@ -59,7 +59,7 @@ def copy_attention_weights(
nn.init.zeros_(new_attn.lambda_k1)
nn.init.zeros_(new_attn.lambda_q2)
nn.init.zeros_(new_attn.lambda_k2)
-new_attn.lambda_init = 0.0
+nn.init.zeros_(new_attn.lambda_init)
logger.debug(
"Copied positive attention weights from %s to %s",
@@ -105,6 +105,7 @@ def convert_to_diff_attention(
)
# Copy weights from old attention to new attention
+new_attention.to(child.q_proj.weight.device)
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer

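Two related fixes in this file: the freshly built differential attention module is moved onto the same device as the layer it replaces before weights are copied over, and lambda_init is now zeroed in place with nn.init.zeros_ instead of being overwritten with a Python float. The in-place form is required once lambda_init becomes a registered nn.Parameter (next file), since assigning a bare float to a parameter attribute raises a TypeError. A toy sketch, not the real attention class:

    import torch
    import torch.nn as nn

    attn = nn.Module()
    attn.lambda_init = nn.Parameter(torch.full((), 0.8), requires_grad=False)

    nn.init.zeros_(attn.lambda_init)  # zeroes the registered Parameter in place
    # attn.lambda_init = 0.0          # TypeError: cannot assign 'float' as parameter
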
View File

@@ -70,13 +70,12 @@ class LlamaDifferentialAttention(nn.Module):
self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scaling = self.head_dim**-0.5
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
-dtype = getattr(config, "torch_dtype", torch.float32)
+dtype = torch.float32
# For Q1 and Q2
self.q_proj = nn.Linear(
@@ -111,7 +110,10 @@ class LlamaDifferentialAttention(nn.Module):
)
# Initialize differential attention parameters
-self.lambda_init = lambda_init_fn(self.layer_idx)
+self.lambda_init = nn.Parameter(
+torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
+requires_grad=False,
+)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
)
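
This is the core of the save/load fix. As a plain float attribute, lambda_init never appeared in the module's state dict, so any value set after construction (for example the zero_init path in the conversion above) was lost on a save_pretrained / from_pretrained round trip and silently recomputed from lambda_init_fn on reload. As a 0-dim nn.Parameter with requires_grad=False it is serialized and restored with the rest of the weights; registering it as a buffer would serialize the same way. A toy comparison, not the axolotl class:

    import torch
    import torch.nn as nn

    class PlainAttr(nn.Module):
        def __init__(self):
            super().__init__()
            self.lambda_init = 0.8  # ordinary attribute: invisible to state_dict()

    class ParamAttr(nn.Module):
        def __init__(self):
            super().__init__()
            self.lambda_init = nn.Parameter(torch.full((), 0.8), requires_grad=False)

    print(list(PlainAttr().state_dict()))  # []
    print(list(ParamAttr().state_dict()))  # ['lambda_init']
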
@@ -197,6 +199,14 @@ class LlamaDifferentialAttention(nn.Module):
self.head_dim
)
+# Add this debug step right after computing attention weights in the forward pass
+attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
+self.head_dim
+)
+attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
+self.head_dim
+)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn_weights1 = attn_weights1 + causal_mask

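The attn_weights1 / attn_weights2 lines added to the forward pass produce the two score maps that differential attention subtracts from each other. A hedged sketch of the downstream combination, following the Differential Transformer formulation rather than the exact code in this file (shapes and names are illustrative):

    import math
    import torch

    b, h, n, d = 1, 2, 4, 8
    q1, k1, q2, k2, v = (torch.randn(b, h, n, d) for _ in range(5))
    lambda_q1, lambda_k1, lambda_q2, lambda_k2 = (torch.randn(d) * 0.1 for _ in range(4))
    lambda_init = torch.tensor(0.8)

    attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(d)
    attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(d)

    # lambda is re-parameterized from the learned vectors plus lambda_init, then the
    # second softmax map is subtracted from the first to cancel attention noise.
    lam = torch.exp(lambda_q1 @ lambda_k1) - torch.exp(lambda_q2 @ lambda_k2) + lambda_init
    attn = torch.softmax(attn_weights1, dim=-1) - lam * torch.softmax(attn_weights2, dim=-1)
    out = torch.matmul(attn, v)
    print(out.shape)  # torch.Size([1, 2, 4, 8])
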
View File

@@ -9,16 +9,13 @@ from .multihead_diffattn import (
)
-def patch_transformers():
+def patch_llama_attention_classes():
"""Patch transformers to support differential attention"""
# Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
-# Store original method for use in our patch
-# original_autoset = PreTrainedModel._autoset_attn_implementation
@classmethod
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
config._attn_implementation_autoset = True # pylint: disable=protected-access

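The rename from patch_transformers to patch_llama_attention_classes matches what the function actually does: it registers the differential classes under new keys in LLAMA_ATTENTION_CLASSES and overrides PreTrainedModel._autoset_attn_implementation (the private hook named in the deleted comment) so the custom keys are not rejected when the model is built. A hedged sketch of that override; the real hook takes more keyword arguments, which **kwargs absorbs here, and the return value is assumed:

    from transformers import PreTrainedModel

    @classmethod
    def new_autoset(cls, config, **kwargs):
        # Replacing the hook bypasses the stock validation of attn_implementation
        # strings; the flag mirrors what the patched function in the diff sets.
        config._attn_implementation_autoset = True
        return config

    PreTrainedModel._autoset_attn_implementation = new_autoset
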
View File

@@ -444,6 +444,13 @@ class ModelLoader:
patch_mistral_cross_entropy()
+if self.cfg.diff_attention:
+from axolotl.integrations.diff_transformer.patches import (
+patch_llama_attention_classes,
+)
+patch_llama_attention_classes()
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1058,10 +1065,6 @@ class ModelLoader:
integrate_rope_embeddings()
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-from axolotl.integrations.diff_transformer.patches import patch_transformers
-patch_transformers()
self.apply_patches()
self.set_auto_model_loader()
self.set_device_map_config()