various improvements
@@ -13,7 +13,7 @@ from dotenv import load_dotenv
 from transformers import HfArgumentParser

 from axolotl.cli import load_cfg, print_axolotl_text_art
-from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

 LOG = logging.getLogger("axolotl.cli.convert_attention")
@@ -67,21 +67,23 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         )

         # Test original model
-        LOG.info("Testing original model...")
-        orig_time, orig_text = test_inference(model, tokenizer)
+        if cli_args.debug:
+            LOG.info("Testing original model...")
+            orig_time, orig_text = test_inference(model, tokenizer)

         # Convert attention
         LOG.info("Converting to differential attention...")
         try:
-            model = convert_to_diff_attention(model)
+            model = convert_to_diff_attention(model, cli_args.zero_init)
             model.to(model.device)
         except Exception as exc:
             LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
             raise

         # Test converted model
-        LOG.info("Testing converted model...")
-        conv_time, conv_text = test_inference(model, tokenizer)
+        if cli_args.debug:
+            LOG.info("Testing converted model...")
+            conv_time, conv_text = test_inference(model, tokenizer)

         # Save if requested
         if cfg.output_dir:
@@ -106,30 +108,65 @@ def convert_diff_transformer(cfg, cli_args, config_path):
             LOG.info("Not saving converted model to disk")
             LOG.info("Pass --output-dir path/to/save to save model")

-        LOG.info(
-            Fore.GREEN
-            + "Conversion successful!\n"
-            + f"Original generation time: {orig_time:.2f}s\n"
-            + f"Converted generation time: {conv_time:.2f}s"
-            + Fore.RESET
-        )
-
-        if orig_text == conv_text:
-            if cli_args.debug:
-                LOG.info(
-                    Fore.GREEN
-                    + "Generations match!\n"
-                    + f"Model generation: {orig_text}\n"
-                    + Fore.RESET
-                )
-        else:
-            LOG.info(
-                Fore.RED
-                + "Generations do not match.\n"
-                + f"Original generation: {orig_text}\n"
-                + f"Converted generation: {conv_text}\n"
-                + Fore.RESET
-            )
+        if cli_args.debug:
+            LOG.info(
+                Fore.GREEN
+                + "Conversion successful!\n"
+                + f"Original generation time: {orig_time:.2f}s\n"
+                + f"Converted generation time: {conv_time:.2f}s"
+                + Fore.RESET
+            )
+
+            if orig_text == conv_text:
+                LOG.info(
+                    Fore.GREEN
+                    + "Generations match!\n"
+                    + "Model generation:\n"
+                    + "*" * 50
+                    + "\n"
+                    + f"{orig_text}\n"
+                    + "*" * 50
+                    + "\n"
+                    + Fore.RESET
+                )
+            else:
+                if cli_args.zero_init:
+                    LOG.info(
+                        Fore.RED
+                        + "Generations do not match.\n"
+                        + "Original generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{orig_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "Converted generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{conv_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + Fore.RESET
+                    )
+                else:
+                    LOG.info(
+                        Fore.YELLOW
+                        + "Generations do not match.\n"
+                        + "Original generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{orig_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "Converted generation:\n"
+                        + "*" * 50
+                        + "\n"
+                        + f"{conv_text}\n"
+                        + "*" * 50
+                        + "\n"
+                        + "However, this is expected since --zero-init was not passed."
+                        + Fore.RESET
+                    )
     except Exception as exc:
         LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
         raise
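For reference, a minimal sketch of the kind of before/after timing check the debug path above performs. The helper below is a hypothetical stand-in for the test_inference(model, tokenizer) function used in these hunks (its body is not shown in this diff); the prompt and generation settings are illustrative only.

import time

import torch


def generate_sample(model, tokenizer, prompt="The differential transformer"):
    # Hypothetical stand-in for test_inference: time a short greedy generation
    # and return (elapsed_seconds, decoded_text).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    start = time.time()
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    elapsed = time.time() - start
    return elapsed, tokenizer.decode(output_ids[0], skip_special_tokens=True)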
@@ -139,7 +176,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
print_axolotl_text_art()
|
||||
|
||||
cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
|
||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
convert_diff_transformer(cfg, cli_args, config)
|
||||
|
||||
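For reference, a small sketch of how HfArgumentParser maps the dataclass fields to command-line flags in do_cli. The inline dataclass below is a simplified stand-in for ConvertDiffTransformerCliArgs (the real one is added to axolotl.common.cli later in this diff), and the example args list is illustrative.

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ConvertDiffTransformerCliArgs:
    # Simplified stand-in for the dataclass defined in axolotl.common.cli below.
    debug: bool = field(default=False)
    zero_init: bool = field(default=False)


parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
# Each boolean field becomes an optional flag; with return_remaining_strings=True,
# unrecognized arguments are handed back instead of raising.
cli_args, remaining = parser.parse_args_into_dataclasses(
    args=["--debug", "--zero_init"], return_remaining_strings=True
)
print(cli_args.debug, cli_args.zero_init, remaining)  # True True []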
@@ -12,7 +12,7 @@ from axolotl.cli.utils import (
     build_command,
     fetch_from_github,
 )
-from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

@@ -242,7 +242,7 @@ def merge_lora(

 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
-@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 def convert_diff_transformer(config: str, **kwargs):
     """Convert model attention layers to differential attention layers."""
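For reference, a rough sketch of the click interface the swapped decorator presumably produces. It assumes add_options_from_dataclass exposes each dataclass field as an option (its implementation is not shown here); the flag spelling --zero-init is taken from the log message earlier in this diff, and the function name below is a placeholder.

import click


@click.command("convert-diff-transformer")
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--debug", is_flag=True, default=False)
@click.option("--zero-init", is_flag=True, default=False)
def convert_diff_transformer_sketch(config: str, **kwargs):
    """Convert model attention layers to differential attention layers."""
    # kwargs holds {"debug": ..., "zero_init": ...}, mirroring ConvertDiffTransformerCliArgs.
    click.echo(f"config={config}, options={kwargs}")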
@@ -18,7 +18,7 @@ LOG = logging.getLogger("axolotl.common.cli")
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
dataclass with arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
@@ -31,7 +31,7 @@ class PreprocessCliArgs:
 @dataclass
 class TrainerCliArgs:
     """
-    dataclass representing the various non-training arguments
+    dataclass with various non-training arguments
     """

     debug: bool = field(default=False)
@@ -46,7 +46,7 @@ class TrainerCliArgs:
 @dataclass
 class EvaluateCliArgs:
     """
-    dataclass representing the various evaluation arguments
+    dataclass with various evaluation arguments
     """

     debug: bool = field(default=False)
@@ -54,6 +54,16 @@ class EvaluateCliArgs:
     debug_num_examples: int = field(default=0)


+@dataclass
+class ConvertDiffTransformerCliArgs:
+    """
+    dataclass with arguments for convert-diff-transformer CLI
+    """
+
+    debug: bool = field(default=False)
+    zero_init: bool = field(default=False)
+
+
 def load_model_and_tokenizer(
     *,
     cfg: DictDefault,
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 def copy_attention_weights(
     old_attn: Union[LlamaAttention, LlamaSdpaAttention],
     new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
-    zero_init: bool = True,
+    zero_init: bool = False,
 ) -> None:
     """
     Copy weights from old attention layer to new differential attention layer.
@@ -68,7 +68,9 @@ def copy_attention_weights(
     )


-def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
+def convert_to_diff_attention(
+    model: PreTrainedModel, zero_init: bool
+) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
     attention_patterns = (
         LlamaAttention,
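For reference, a minimal sketch of the updated calling convention. The checkpoint path is a placeholder for any Llama-family model directory, and the import path is the one shown in the first hunk of this diff.

from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

# "path/to/llama-checkpoint" is a placeholder, not a path referenced by this commit.
model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")

# zero_init is now an explicit argument rather than a keyword default.
model = convert_to_diff_attention(model, zero_init=True)
# Mirrors the call in the hunks above; a no-op unless conversion created
# parameters on a different device.
model.to(model.device)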
@@ -78,9 +80,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
||||
)
|
||||
layer_idx = 0
|
||||
|
||||
# Get model dtype from existing weights
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
|
||||
def convert_module(module):
|
||||
nonlocal layer_idx
|
||||
|
||||
@@ -103,11 +102,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
|
||||
new_attention = attention_class(
|
||||
config=module.config if hasattr(module, "config") else model.config,
|
||||
layer_idx=layer_idx,
|
||||
dtype=model_dtype,
|
||||
)
|
||||
|
||||
# Copy weights from old attention to new attention
|
||||
copy_attention_weights(child, new_attention)
|
||||
copy_attention_weights(child, new_attention, zero_init=zero_init)
|
||||
|
||||
# Replace the layer
|
||||
setattr(module, name, new_attention)
|
||||
|
||||
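For context, one common zero-init scheme for this kind of conversion is to zero the weights feeding the second, subtracted attention map so the differential term starts at zero and the converted layer initially reproduces the original attention. The sketch below only illustrates that idea under those assumptions; it is not the actual copy_attention_weights implementation, which this diff does not show.

import torch
from torch import nn


def zero_init_second_half(proj: nn.Linear) -> None:
    # Illustrative only: zero the second half of a doubled Q/K projection so the
    # subtracted attention map contributes nothing at initialization.
    half = proj.out_features // 2
    with torch.no_grad():
        proj.weight[half:].zero_()


# Example: a doubled query projection (Q1 and Q2 stacked along the output dim).
q_proj = nn.Linear(512, 2 * 512, bias=False)
zero_init_second_half(q_proj)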
@@ -60,11 +60,10 @@ class LlamaDifferentialAttention(nn.Module):
         self,
         config: Any,
         layer_idx: int,
-        dtype: torch.dtype,
     ):
         super().__init__()

-        # Base model dimensions
+        # Base model config
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.base_num_heads = config.num_attention_heads
@@ -77,6 +76,8 @@ class LlamaDifferentialAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True

+        dtype = getattr(config, "torch_dtype", torch.float32)
+
         # For Q1 and Q2
         self.q_proj = nn.Linear(
             self.hidden_size,
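For reference, a small sketch of the new dtype handling: the layer now reads torch_dtype off the config (falling back to float32) instead of taking a dtype constructor argument. The config values and the doubled output width below are illustrative assumptions, not values taken from this diff.

import torch
from torch import nn
from transformers import LlamaConfig

config = LlamaConfig(hidden_size=512, num_attention_heads=8, torch_dtype=torch.bfloat16)

# Fall back to float32 when the config does not carry a torch_dtype attribute.
dtype = getattr(config, "torch_dtype", torch.float32)

# For Q1 and Q2: projections can be created directly in the config's dtype.
q_proj = nn.Linear(config.hidden_size, 2 * config.hidden_size, bias=False, dtype=dtype)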
@@ -1,4 +1,5 @@
 """Patches related to differential transformers implementation."""

+from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

@@ -710,13 +710,6 @@ class ModelLoader:
         """
        sample packing uses custom FA2 patch
        """
-        print(
-            self.cfg.flash_attention,
-            self.cfg.sdp_attention,
-            self.cfg.eager_attention,
-            self.cfg.diff_attention,
-        )
-
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
@@ -761,9 +754,6 @@ class ModelLoader:
                 "differential_eager"
             )

-        if "attn_implementation" in self.model_kwargs:
-            print(self.model_kwargs["attn_implementation"])
-
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True