fix model save / load logic

This commit is contained in:
Dan Saunders
2024-12-17 04:43:08 +00:00
parent 845dbede53
commit 503c4e9ffa
7 changed files with 35 additions and 31 deletions

View File

@@ -75,7 +75,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
LOG.info("Converting to differential attention...") LOG.info("Converting to differential attention...")
try: try:
model = convert_to_diff_attention(model, cli_args.zero_init) model = convert_to_diff_attention(model, cli_args.zero_init)
model.to(model.device) model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc: except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise raise

View File

@@ -4,7 +4,7 @@ shared module for cli specific things
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -67,7 +67,7 @@ class ConvertDiffTransformerCliArgs:
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: TrainerCliArgs, cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
): ):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)

View File

@@ -9,13 +9,13 @@ from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.models import load_processor
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src") src_dir = os.path.join(project_root, "src")
@@ -63,7 +63,7 @@ def evaluate_dataset(
def evaluate( def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]: ) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets
@@ -83,12 +83,11 @@ def evaluate(
# Enable expandable segments for cuda allocation to improve VRAM usage # Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf() set_pytorch_cuda_alloc_conf()
# Load tokenizer # Load model
LOG.debug( LOG.debug("loading model for evaluation...")
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True, model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
) model = model.to(cfg.device, dtype=cfg.torch_dtype)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed # Load processor for multimodal models if needed
processor = None processor = None
@@ -100,12 +99,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer # Set up trainer
trainer = setup_trainer( trainer = setup_trainer(
cfg, cfg,

View File

@@ -59,7 +59,7 @@ def copy_attention_weights(
nn.init.zeros_(new_attn.lambda_k1) nn.init.zeros_(new_attn.lambda_k1)
nn.init.zeros_(new_attn.lambda_q2) nn.init.zeros_(new_attn.lambda_q2)
nn.init.zeros_(new_attn.lambda_k2) nn.init.zeros_(new_attn.lambda_k2)
new_attn.lambda_init = 0.0 nn.init.zeros_(new_attn.lambda_init)
logger.debug( logger.debug(
"Copied positive attention weights from %s to %s", "Copied positive attention weights from %s to %s",
@@ -105,6 +105,7 @@ def convert_to_diff_attention(
) )
# Copy weights from old attention to new attention # Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device)
copy_attention_weights(child, new_attention, zero_init=zero_init) copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer # Replace the layer

View File

@@ -70,13 +70,12 @@ class LlamaDifferentialAttention(nn.Module):
self.base_num_kv_heads = config.num_key_value_heads self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads
self.scaling = self.head_dim**-0.5
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
dtype = getattr(config, "torch_dtype", torch.float32) dtype = torch.float32
# For Q1 and Q2 # For Q1 and Q2
self.q_proj = nn.Linear( self.q_proj = nn.Linear(
@@ -111,7 +110,10 @@ class LlamaDifferentialAttention(nn.Module):
) )
# Initialize differential attention parameters # Initialize differential attention parameters
self.lambda_init = lambda_init_fn(self.layer_idx) self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter( self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1) torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
) )
@@ -197,6 +199,14 @@ class LlamaDifferentialAttention(nn.Module):
self.head_dim self.head_dim
) )
# Debug: recompute the scaled attention scores (Q @ K^T / sqrt(head_dim)) for both attention maps; appears to duplicate the computation just above — TODO confirm and remove
attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
self.head_dim
)
if attention_mask is not None: if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]] causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn_weights1 = attn_weights1 + causal_mask attn_weights1 = attn_weights1 + causal_mask

View File

@@ -9,16 +9,13 @@ from .multihead_diffattn import (
) )
def patch_transformers(): def patch_llama_attention_classes():
"""Patch transformers to support differential attention""" """Patch transformers to support differential attention"""
# Add our attention class to the registry # Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
# Store original method for use in our patch
# original_autoset = PreTrainedModel._autoset_attn_implementation
@classmethod @classmethod
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
config._attn_implementation_autoset = True # pylint: disable=protected-access config._attn_implementation_autoset = True # pylint: disable=protected-access

View File

@@ -444,6 +444,13 @@ class ModelLoader:
patch_mistral_cross_entropy() patch_mistral_cross_entropy()
if self.cfg.diff_attention:
from axolotl.integrations.diff_transformer.patches import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1058,10 +1065,6 @@ class ModelLoader:
integrate_rope_embeddings() integrate_rope_embeddings()
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
from axolotl.integrations.diff_transformer.patches import patch_transformers
patch_transformers()
self.apply_patches() self.apply_patches()
self.set_auto_model_loader() self.set_auto_model_loader()
self.set_device_map_config() self.set_device_map_config()