various improvements

Dan Saunders
2024-12-13 15:03:45 -05:00
parent af1d8d69af
commit 7108ca72b4
7 changed files with 87 additions and 50 deletions

View File

@@ -13,7 +13,7 @@ from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
LOG = logging.getLogger("axolotl.cli.convert_attention")
@@ -67,21 +67,23 @@ def convert_diff_transformer(cfg, cli_args, config_path):
)
# Test original model
LOG.info("Testing original model...")
orig_time, orig_text = test_inference(model, tokenizer)
if cli_args.debug:
LOG.info("Testing original model...")
orig_time, orig_text = test_inference(model, tokenizer)
# Convert attention
LOG.info("Converting to differential attention...")
try:
model = convert_to_diff_attention(model)
model = convert_to_diff_attention(model, cli_args.zero_init)
model.to(model.device)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
LOG.info("Testing converted model...")
conv_time, conv_text = test_inference(model, tokenizer)
if cli_args.debug:
LOG.info("Testing converted model...")
conv_time, conv_text = test_inference(model, tokenizer)
# Save if requested
if cfg.output_dir:
@@ -106,30 +108,65 @@ def convert_diff_transformer(cfg, cli_args, config_path):
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {orig_time:.2f}s\n"
+ f"Converted generation time: {conv_time:.2f}s"
+ Fore.RESET
)
if orig_text == conv_text:
if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ f"Model generation: {orig_text}\n"
+ Fore.RESET
)
else:
LOG.info(
Fore.RED
+ "Generations do not match.\n"
+ f"Original generation: {orig_text}\n"
+ f"Converted generation: {conv_text}\n"
+ "Conversion successful!\n"
+ f"Original generation time: {orig_time:.2f}s\n"
+ f"Converted generation time: {conv_time:.2f}s"
+ Fore.RESET
)
if orig_text == conv_text:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
else:
if cli_args.zero_init:
LOG.info(
Fore.RED
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
else:
LOG.info(
Fore.YELLOW
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ "However, this is expected since --zero-init was not passed."
+ Fore.RESET
)
except Exception as exc:
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
raise
@@ -139,7 +176,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_diff_transformer(cfg, cli_args, config)
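The hunks above gate the before/after inference comparison behind --debug and thread the new --zero-init flag through to the conversion. A minimal sketch of driving the same entry point from Python follows; the module path is inferred from the logger name above and the config file is a placeholder, neither is confirmed by this diff.

# Hedged sketch: calling the conversion entry point programmatically.
# The module path (axolotl.cli.convert_attention) is inferred from the logger
# name above, and the config path is a placeholder, not taken from this diff.
from axolotl.cli import load_cfg
from axolotl.cli.convert_attention import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs

config_path = "examples/llama-3/lora-1b.yml"  # hypothetical example config
cfg = load_cfg(config_path)
cli_args = ConvertDiffTransformerCliArgs(debug=True, zero_init=True)

# debug=True runs the inference test before and after conversion; zero_init=True
# makes a generation mismatch an error (red) rather than expected (yellow).
convert_diff_transformer(cfg, cli_args, config_path)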

View File

@@ -12,7 +12,7 @@ from axolotl.cli.utils import (
build_command,
fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -242,7 +242,7 @@ def merge_lora(
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""

View File

@@ -18,7 +18,7 @@ LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
dataclass with arguments for preprocessing only
"""
debug: bool = field(default=False)
@@ -31,7 +31,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
dataclass with various non-training arguments
"""
debug: bool = field(default=False)
@@ -46,7 +46,7 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
dataclass with various evaluation arguments
"""
debug: bool = field(default=False)
@@ -54,6 +54,16 @@ class EvaluateCliArgs:
debug_num_examples: int = field(default=0)
@dataclass
class ConvertDiffTransformerCliArgs:
"""
dataclass with arguments for convert-diff-transformer CLI
"""
debug: bool = field(default=False)
zero_init: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
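The new dataclass carries just two booleans. A small sketch of exercising it with HfArgumentParser, mirroring the do_cli() change in the first file; the inline argument list is only for illustration.

# Minimal sketch: parse ConvertDiffTransformerCliArgs the same way do_cli() does,
# but with an explicit argv for illustration.
from transformers import HfArgumentParser

from axolotl.common.cli import ConvertDiffTransformerCliArgs

parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(
    args=["--debug", "--zero_init"],
    return_remaining_strings=True,
)
print(cli_args.debug, cli_args.zero_init)  # True True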

View File

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention],
new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
zero_init: bool = True,
zero_init: bool = False,
) -> None:
"""
Copy weights from old attention layer to new differential attention layer.
@@ -68,7 +68,9 @@ def copy_attention_weights(
)
def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
def convert_to_diff_attention(
model: PreTrainedModel, zero_init: bool
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
attention_patterns = (
LlamaAttention,
@@ -78,9 +80,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
)
layer_idx = 0
# Get model dtype from existing weights
model_dtype = next(model.parameters()).dtype
def convert_module(module):
nonlocal layer_idx
@@ -103,11 +102,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,
layer_idx=layer_idx,
dtype=model_dtype,
)
# Copy weights from old attention to new attention
copy_attention_weights(child, new_attention)
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer
setattr(module, name, new_attention)
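With the updated signature convert_to_diff_attention(model, zero_init), a direct call looks roughly like the sketch below. The checkpoint name and prompt are placeholders; per the attention_patterns tuple above, a Llama-family model with standard or SDPA attention layers is the expected input.

# Hedged usage sketch for convert_to_diff_attention(model, zero_init).
# The checkpoint and prompt are placeholders, not taken from this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# zero_init=True is forwarded to copy_attention_weights(); per the CLI above,
# generations are then expected to match the unconverted model.
model = convert_to_diff_attention(model, zero_init=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))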

View File

@@ -60,11 +60,10 @@ class LlamaDifferentialAttention(nn.Module):
self,
config: Any,
layer_idx: int,
dtype: torch.dtype,
):
super().__init__()
# Base model dimensions
# Base model config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads
@@ -77,6 +76,8 @@ class LlamaDifferentialAttention(nn.Module):
self.rope_theta = config.rope_theta
self.is_causal = True
dtype = getattr(config, "torch_dtype", torch.float32)
# For Q1 and Q2
self.q_proj = nn.Linear(
self.hidden_size,
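The dtype argument is gone from the constructor; the layer now derives it from the model config via getattr, with torch.float32 as the fallback. A tiny sketch of that lookup, using a bare LlamaConfig as a stand-in:

# Minimal sketch of the dtype fallback used above (LlamaConfig is just a stand-in).
import torch
from transformers import LlamaConfig

config = LlamaConfig(torch_dtype=torch.bfloat16)
dtype = getattr(config, "torch_dtype", torch.float32)
print(dtype)  # torch.bfloat16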

View File

@@ -1,4 +1,5 @@
"""Patches related to differential transformers implementation."""
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

View File

@@ -710,13 +710,6 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
print(
self.cfg.flash_attention,
self.cfg.sdp_attention,
self.cfg.eager_attention,
self.cfg.diff_attention,
)
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
@@ -761,9 +754,6 @@ class ModelLoader:
"differential_eager"
)
if "attn_implementation" in self.model_kwargs:
print(self.model_kwargs["attn_implementation"])
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True