fix model save / load logic
@@ -75,7 +75,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
     LOG.info("Converting to differential attention...")
     try:
         model = convert_to_diff_attention(model, cli_args.zero_init)
-        model.to(model.device)
+        model.to(cfg.device, dtype=cfg.torch_dtype)
     except Exception as exc:
         LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
         raise

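Note: the replaced call is the visible half of the save/load fix. `model.to(model.device)` moves the model to the device it is already on and never touches dtype, so parameters created during conversion kept their construction dtype and were saved that way; casting with `cfg.device` and `cfg.torch_dtype` normalizes precision before the checkpoint is written. A toy demonstration of the `Module.to` semantics (plain PyTorch, not axolotl code):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)            # parameters are created in float32
    layer.to(layer.weight.device)      # no-op: same device, dtype untouched
    layer.to(torch.device("cpu"), dtype=torch.bfloat16)  # actually casts
    print(layer.weight.dtype)          # torch.bfloat16
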
@@ -4,7 +4,7 @@ shared module for cli specific things
 
 import logging
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.logging_config import configure_logging

@@ -67,7 +67,7 @@ class ConvertDiffTransformerCliArgs:
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
-    cli_args: TrainerCliArgs,
+    cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
 ):
     LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
     tokenizer = load_tokenizer(cfg)

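Note: with the widened annotation, this shared loader becomes the single entry point for the train, evaluate, and convert-diff-transformer CLIs. A hedged usage sketch built only from names visible in this diff (the dataclass fields are not otherwise verified):

    from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer

    # cfg: a DictDefault loaded from a YAML config elsewhere
    cli_args = ConvertDiffTransformerCliArgs(zero_init=True)
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
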
@@ -9,13 +9,13 @@ from typing import Dict, Optional
 import torch
 from accelerate.logging import get_logger
 
-from axolotl.common.cli import TrainerCliArgs
+from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
-from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_processor, load_tokenizer
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.models import load_processor
+from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")

@@ -63,7 +63,7 @@ def evaluate_dataset(
 
 
 def evaluate(
-    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
+    *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
 ) -> Dict[str, float]:
     """
     Evaluate a model on training and validation datasets

@@ -83,12 +83,11 @@ def evaluate(
     # Enable expandable segments for cuda allocation to improve VRAM usage
     set_pytorch_cuda_alloc_conf()
 
-    # Load tokenizer
-    LOG.debug(
-        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
-        main_process_only=True,
-    )
-    tokenizer = load_tokenizer(cfg)
+    # Load model
+    LOG.debug("loading model for evaluation...")
+
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     # Load processor for multimodal models if needed
     processor = None

@@ -100,12 +99,6 @@ def evaluate(
     eval_dataset = dataset_meta.eval_dataset
     total_num_steps = dataset_meta.total_num_steps
 
-    # Load model
-    LOG.debug("loading model for evaluation...")
-    model, _ = load_model(
-        cfg, tokenizer, processor=processor, inference=cli_args.inference
-    )
-
     # Set up trainer
     trainer = setup_trainer(
         cfg,

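Note: together with the previous hunk, this removes the duplicate model load: evaluate() used to load the tokenizer, then load the model here, separately. Routing both through the shared helper materializes the checkpoint once, with a consistent device and dtype. A minimal sketch of what the helper plausibly does, assuming the load_tokenizer/load_model signatures visible in this diff:

    def load_model_and_tokenizer(*, cfg, cli_args):
        # Hypothetical body for illustration; the real helper lives in
        # axolotl.common.cli and may differ.
        tokenizer = load_tokenizer(cfg)
        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
        return model, tokenizer
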
@@ -59,7 +59,7 @@ def copy_attention_weights(
         nn.init.zeros_(new_attn.lambda_k1)
         nn.init.zeros_(new_attn.lambda_q2)
         nn.init.zeros_(new_attn.lambda_k2)
-        new_attn.lambda_init = 0.0
+        nn.init.zeros_(new_attn.lambda_init)
 
     logger.debug(
         "Copied positive attention weights from %s to %s",

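Note: this one-line change is the heart of the save/load fix. `new_attn.lambda_init = 0.0` rebinds the attribute to a plain Python float, which never appears in `state_dict()` and is therefore silently dropped when the converted model is saved and reloaded; `nn.init.zeros_` instead zeroes the registered parameter in place. A toy module showing the difference:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.plain = 0.8                 # plain attribute: lost on save/load
            self.param = nn.Parameter(torch.tensor(0.8), requires_grad=False)

    print(Toy().state_dict().keys())  # odict_keys(['param'])
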
@@ -105,6 +105,7 @@ def convert_to_diff_attention(
         )
 
         # Copy weights from old attention to new attention
+        new_attention.to(child.q_proj.weight.device)
         copy_attention_weights(child, new_attention, zero_init=zero_init)
 
         # Replace the layer

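Note: moving the fresh module onto the device of the source layer's weights before copying keeps the replacement on the correct shard when the model was loaded with a device map; otherwise the new layer would still sit on CPU after being spliced into a GPU-resident model. The pattern in isolation (toy tensors, not axolotl code):

    import torch.nn as nn

    old_layer = nn.Linear(8, 8)             # stands in for child
    new_layer = nn.Linear(8, 8)             # stands in for new_attention
    new_layer.to(old_layer.weight.device)   # align devices first
    new_layer.weight.data.copy_(old_layer.weight.data)
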
@@ -70,13 +70,12 @@ class LlamaDifferentialAttention(nn.Module):
         self.base_num_kv_heads = config.num_key_value_heads
         self.head_dim = config.hidden_size // config.num_attention_heads
 
         self.scaling = self.head_dim**-0.5
         self.layer_idx = layer_idx
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        dtype = getattr(config, "torch_dtype", torch.float32)
+        dtype = torch.float32
 
         # For Q1 and Q2
         self.q_proj = nn.Linear(

@@ -111,7 +110,10 @@ class LlamaDifferentialAttention(nn.Module):
         )
 
         # Initialize differential attention parameters
-        self.lambda_init = lambda_init_fn(self.layer_idx)
+        self.lambda_init = nn.Parameter(
+            torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
+            requires_grad=False,
+        )
         self.lambda_q1 = nn.Parameter(
             torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
         )

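Note: `torch.full((), v)` builds a zero-dimensional (scalar) tensor, so `lambda_init` keeps its scalar semantics while becoming a non-trainable `nn.Parameter` that serializes with the module, matching the zeros_ fix above. The init function itself is not shown in this diff; a hedged sketch of the depth-dependent schedule from the Differential Transformer paper, which `lambda_init_fn` presumably implements:

    import math

    def lambda_init_fn(depth: int) -> float:
        # lambda_init = 0.8 - 0.6 * exp(-0.3 * depth), per the paper;
        # the repo's implementation may differ in indexing.
        return 0.8 - 0.6 * math.exp(-0.3 * depth)
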
@@ -197,6 +199,14 @@ class LlamaDifferentialAttention(nn.Module):
             self.head_dim
         )
 
+        # Compute the two attention score maps used by differential attention
+        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
+        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
+
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
             attn_weights1 = attn_weights1 + causal_mask

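Note: the forward pass now materializes two score maps, one per query/key split. Downstream, differential attention subtracts the second softmax map from the first, weighted by a learned scalar assembled from the lambda parameters initialized above. A hedged sketch of that combination step, following the Differential Transformer paper rather than this repo's exact code:

    import torch
    import torch.nn.functional as F

    def differential_softmax(w1, w2, lambda_q1, lambda_k1,
                             lambda_q2, lambda_k2, lambda_init):
        # lambda is re-parameterized through exponentials for stability
        lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1))
        lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1))
        lambda_full = lambda_1 - lambda_2 + lambda_init
        # subtract the second map to cancel common-mode attention noise
        return F.softmax(w1, dim=-1) - lambda_full * F.softmax(w2, dim=-1)
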
@@ -9,16 +9,13 @@ from .multihead_diffattn import (
 )
 
 
-def patch_transformers():
+def patch_llama_attention_classes():
     """Patch transformers to support differential attention"""
 
     # Add our attention class to the registry
     LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
     LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
 
-    # Store original method for use in our patch
-    # original_autoset = PreTrainedModel._autoset_attn_implementation
-
     @classmethod
     def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
         config._attn_implementation_autoset = True  # pylint: disable=protected-access

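Note: the rename is apt, since the function patches the Llama attention class registry rather than transformers wholesale. `LLAMA_ATTENTION_CLASSES` maps `config._attn_implementation` strings to attention classes, so registering two new keys makes the differential variants selectable per config. A hedged usage sketch (attribute names follow transformers' convention at the time of this diff):

    from transformers import LlamaConfig

    config = LlamaConfig()
    # pylint: disable=protected-access
    config._attn_implementation = "differential_sdpa"  # or "differential_eager"
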
@@ -444,6 +444,13 @@ class ModelLoader:
 
         patch_mistral_cross_entropy()
 
+        if self.cfg.diff_attention:
+            from axolotl.integrations.diff_transformer.patches import (
+                patch_llama_attention_classes,
+            )
+
+            patch_llama_attention_classes()
+
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:

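Note: the patch is now applied only when the config opts in, and the import stays local so the integration module is never imported on ordinary runs. A hedged sketch of the opt-in from the config side (key name taken from this hunk):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault({"diff_attention": True})
    # ModelLoader.apply_patches() will then register the differential
    # attention classes before the model is instantiated.
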
@@ -1058,10 +1065,6 @@ class ModelLoader:
         integrate_rope_embeddings()
 
     def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
-        from axolotl.integrations.diff_transformer.patches import patch_transformers
-
-        patch_transformers()
-
         self.apply_patches()
         self.set_auto_model_loader()
         self.set_device_map_config()