Compare commits

..

27 Commits

Author SHA1 Message Date
Wing Lian
6d3f4b9ab5 keep some softmax layers 2025-01-15 15:08:32 -05:00
Wing Lian
12aade921a various fixes 2025-01-15 13:25:16 -05:00
Wing Lian
198f01f902 remove bias term in phi and add custom modeling code 2025-01-15 13:25:16 -05:00
Wing Lian
2e6265090f remove lr_groups from other branch 2025-01-15 13:25:16 -05:00
Wing Lian
1c5b78621c fix forward sig
more fixes
2025-01-15 13:25:12 -05:00
Dan Saunders
2717b97103 adding yaml dumper preserving input config format 2024-12-20 20:40:31 +00:00
Dan Saunders
e0adf11b76 removing extra pytest xdist args 2024-12-20 20:40:26 +00:00
Dan Saunders
544f2a8a27 moving tests around for flash_attn install 2024-12-20 20:39:53 +00:00
Dan Saunders
d4e29e5b67 adding split_heads argument for retaining original (Q, K) dimensionality 2024-12-20 20:39:53 +00:00
Dan Saunders
80ba0d8dd1 isolating problematic test 2024-12-20 20:39:53 +00:00
Dan Saunders
dda9b25994 fixes post-rebase 2024-12-20 20:39:53 +00:00
Dan Saunders
0e9c0c6680 plugin implementation 2024-12-20 20:39:53 +00:00
Dan Saunders
b7cc117394 convert-differential-transformer test coverage 2024-12-20 20:39:53 +00:00
Dan Saunders
1fadc5cfe5 duplicate code ignore 2024-12-20 20:39:53 +00:00
Dan Saunders
6425d052bc differential flash attention 2; cleanup 2024-12-20 20:39:53 +00:00
Dan Saunders
594c42f169 moving monkeypatch 2024-12-20 20:39:53 +00:00
Dan Saunders
ae494776e4 pre-commit fix 2024-12-20 20:39:53 +00:00
Dan Saunders
503c4e9ffa fix model save / load logic 2024-12-20 20:39:53 +00:00
Dan Saunders
845dbede53 various improvements 2024-12-20 20:39:53 +00:00
Dan Saunders
7108ca72b4 various improvements 2024-12-20 20:39:53 +00:00
Dan Saunders
af1d8d69af training fixes, patching, minor cleanup 2024-12-20 20:39:53 +00:00
Dan Saunders
e162d36fe9 adding CLI command for convert-diff-transformer 2024-12-20 20:39:53 +00:00
Dan Saunders
7af20b52d6 Adding script for doing conversion; fixes and updates 2024-12-20 20:39:53 +00:00
Dan Saunders
866d7b3040 initial diff attn layer / model conversion implementation (support for llama arch) 2024-12-20 20:39:53 +00:00
Dan Saunders
23ac14540b Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2024-12-20 20:39:53 +00:00
Wing Lian
42bd32a233 add outputs (symlink) to gitignore [skip ci] (#2205) 2024-12-19 20:14:43 -05:00
Dan Saunders
5b8fb5e939 remove cicd pytest xdist args (#2201)
* remove cicd pytest xdist args

* Delete outputs
2024-12-19 11:44:53 -05:00
44 changed files with 2590 additions and 47 deletions

4
.gitignore vendored
View File

@@ -1,6 +1,7 @@
**/axolotl.egg-info
configs
last_run_prepared/
outputs
.vscode
_site/
@@ -185,3 +186,6 @@ out/
# vim
*.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union
import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

View File

@@ -0,0 +1,207 @@
"""CLI to convert a transformers model's attns to diff attns."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Run a short greedy generation and return (elapsed seconds, text)."""
    try:
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {
            name: tensor.to(device=model.device, dtype=torch.long)
            for name, tensor in encoded.items()
        }

        started = time()
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=20,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        duration = time() - started

        text = tokenizer.decode(generated[0], skip_special_tokens=True)

        LOG.info("Prompt: %s", prompt)
        LOG.info("Generated: %s", text)
        LOG.info("Generation time: %.2fs", duration)

        return duration, text
    except Exception as exc:
        LOG.error("Inference failed: %s", str(exc))
        raise
def convert_diff_transformer(cfg, cli_args, config_path):
    """Convert a model's attention layers to differential attention.

    Loads the model/tokenizer from ``cfg``, optionally runs a debug
    generation before and after conversion, converts attention layers via
    ``convert_to_diff_attn``, and (when ``cfg.output_dir`` is set) saves
    the converted model plus an updated axolotl YAML config pointing at it.

    Returns:
        Tuple of (converted model, ``debug_info`` dict; the dict carries
        timing / generation comparison data only when ``cli_args.debug``).
    """
    debug_info = {}

    # Load model and tokenizer (suppress HF loading warnings)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
        model.to(cfg.device, dtype=cfg.torch_dtype)

    # Log original model info
    LOG.info(
        "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
        model.config.hidden_size,
        model.config.num_attention_heads,
    )

    # Test original model (baseline generation for the later comparison)
    if cli_args.debug:
        LOG.info("Testing original model...")
        debug_info["orig_time"], debug_info["orig_text"] = test_inference(
            model, tokenizer
        )

    # Convert attention
    LOG.info("Converting to differential attention...")
    if cli_args.split_heads and cli_args.zero_init:
        # split_heads reuses the original projections, so the zero_init
        # exact-equivalence trick cannot take effect.
        LOG.warning(
            Fore.YELLOW
            + "Warning: Using split_heads with zero_init is not recommended; "
            + "split_heads will preclude the effects of zero_init"
            + Fore.RESET
        )
    try:
        model = convert_to_diff_attn(
            model=model,
            zero_init=cli_args.zero_init,
            sublayer_norm=cli_args.sublayer_norm,
            split_heads=cli_args.split_heads,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise

    # Test converted model
    if cli_args.debug:
        LOG.info("Testing converted model...")
        debug_info["conv_time"], debug_info["conv_text"] = test_inference(
            model, tokenizer
        )

    # Save if requested
    if cfg.output_dir:
        # Save model and tokenizer
        LOG.info("Saving converted model to %s", cfg.output_dir)
        model.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

        # Modify config to reflect new path / differential attention
        output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
        LOG.info("Saving updated config to %s", output_config_path)
        with open(config_path, "r", encoding="utf-8") as file:
            modified_cfg = yaml.safe_load(file) or {}
        modified_cfg["base_model"] = cfg.output_dir
        modified_cfg["diff_attention"] = True
        plugin_class = (
            "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
        )
        if "plugins" in modified_cfg:
            modified_cfg["plugins"].append(plugin_class)
        else:
            modified_cfg["plugins"] = [plugin_class]
        # Re-dump while preserving the user's original key ordering.
        dump_yaml_preserved_order(
            data=modified_cfg,
            reference_yaml_path=config_path,
            output_path=output_config_path,
        )
    else:
        LOG.info("Not saving converted model to disk")
        LOG.info("Pass --output-dir path/to/save to save model")

    if cli_args.debug:
        LOG.info(
            Fore.GREEN
            + "Conversion successful!\n"
            + f"Original generation time: {debug_info['orig_time']:.2f}s\n"
            + f"Converted generation time: {debug_info['conv_time']:.2f}s"
            + Fore.RESET
        )
        if debug_info["orig_text"] == debug_info["conv_text"]:
            LOG.info(
                Fore.GREEN
                + "Generations match!\n"
                + "Model generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + Fore.RESET
            )
            debug_info["generations_match"] = True
        else:
            message = (
                "Generations do not match.\n"
                + "Original generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + "Converted generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['conv_text']}\n"
                + "*" * 50
                + "\n"
            )
            debug_info["generations_match"] = False
            # Matching generations are only expected with --zero-init and
            # --no-sublayer-norm; otherwise a mismatch is normal.
            if cli_args.zero_init and not cli_args.sublayer_norm:
                LOG.info(Fore.RED + message + Fore.RESET)
                debug_info["match_expected"] = True
            else:
                LOG.info(
                    Fore.YELLOW
                    + message
                    + "However, this is expected since --zero-init"
                    + " and --no-sublayer-norm were not passed."
                    + Fore.RESET
                )
                debug_info["match_expected"] = False

    return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """Entry point for the convert-diff-transformer CLI."""
    print_axolotl_text_art()
    cfg = load_cfg(config, **kwargs)
    arg_parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
    cli_args, _ = arg_parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    convert_diff_transformer(cfg, cli_args, config)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)

View File

@@ -0,0 +1,197 @@
"""CLI to convert a transformers model's attns to rala attns."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.rala.convert import convert_to_rala
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Run a short greedy generation and return (elapsed seconds, text)."""
    try:
        encoded = tokenizer(prompt, return_tensors="pt")
        encoded = {
            name: tensor.to(device=model.device, dtype=torch.long)
            for name, tensor in encoded.items()
        }

        started = time()
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=20,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        duration = time() - started

        text = tokenizer.decode(generated[0], skip_special_tokens=True)

        LOG.info("Prompt: %s", prompt)
        LOG.info("Generated: %s", text)
        LOG.info("Generation time: %.2fs", duration)

        return duration, text
    except Exception as exc:
        LOG.error("Inference failed: %s", str(exc))
        raise
def convert_rala(cfg, cli_args, config_path):
    """Convert a model's attention layers to RALA attention.

    Loads the model/tokenizer from ``cfg``, optionally runs a debug
    generation before and after conversion, converts attention layers via
    ``convert_to_rala``, and (when ``cfg.output_dir`` is set) saves the
    converted model plus an updated axolotl YAML config pointing at it.

    Returns:
        Tuple of (converted model, ``debug_info`` dict; the dict carries
        timing / generation comparison data only when ``cli_args.debug``).
    """
    debug_info = {}

    # Load model and tokenizer (suppress HF loading warnings)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
        model.to(cfg.device, dtype=cfg.torch_dtype)

    # Log original model info
    LOG.info(
        "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
        model.config.hidden_size,
        model.config.num_attention_heads,
    )

    # Test original model (baseline generation for the later comparison)
    if cli_args.debug:
        # Fix: previously logged a garbled copy-paste fragment
        # ("attention layers to RALA attention") instead of this message.
        LOG.info("Testing original model...")
        debug_info["orig_time"], debug_info["orig_text"] = test_inference(
            model, tokenizer
        )

    # Convert attention
    LOG.info("Converting attention layers to RALA attention...")
    try:
        model = convert_to_rala(
            model=model,
            zero_init=cli_args.zero_init,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise

    # Test converted model
    if cli_args.debug:
        LOG.info("Testing converted model...")
        debug_info["conv_time"], debug_info["conv_text"] = test_inference(
            model, tokenizer
        )

    # Save if requested
    if cfg.output_dir:
        # Save model and tokenizer
        LOG.info("Saving converted model to %s", cfg.output_dir)
        model.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

        # Modify config to reflect new path / RALA attention
        output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
        LOG.info("Saving updated config to %s", output_config_path)
        with open(config_path, "r", encoding="utf-8") as file:
            modified_cfg = yaml.safe_load(file) or {}
        modified_cfg["base_model"] = cfg.output_dir
        modified_cfg["rala_attention"] = True
        plugin_class = "axolotl.integrations.rala.RalaPlugin"
        if "plugins" in modified_cfg:
            modified_cfg["plugins"].append(plugin_class)
        else:
            modified_cfg["plugins"] = [plugin_class]
        # Re-dump while preserving the user's original key ordering.
        dump_yaml_preserved_order(
            data=modified_cfg,
            reference_yaml_path=config_path,
            output_path=output_config_path,
        )
    else:
        LOG.info("Not saving converted model to disk")
        LOG.info("Pass --output-dir path/to/save to save model")

    if cli_args.debug:
        LOG.info(
            Fore.GREEN
            + "Conversion successful!\n"
            + f"Original generation time: {debug_info['orig_time']:.2f}s\n"
            + f"Converted generation time: {debug_info['conv_time']:.2f}s"
            + Fore.RESET
        )
        if debug_info["orig_text"] == debug_info["conv_text"]:
            LOG.info(
                Fore.GREEN
                + "Generations match!\n"
                + "Model generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + Fore.RESET
            )
            debug_info["generations_match"] = True
        else:
            message = (
                "Generations do not match.\n"
                + "Original generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['orig_text']}\n"
                + "*" * 50
                + "\n"
                + "Converted generation:\n"
                + "*" * 50
                + "\n"
                + f"{debug_info['conv_text']}\n"
                + "*" * 50
                + "\n"
            )
            debug_info["generations_match"] = False
            # NOTE(review): this branch (and its --no-sublayer-norm wording)
            # is inherited from the diff-transformer CLI; confirm whether
            # sublayer_norm is meaningful for RALA conversion.
            if cli_args.zero_init and not cli_args.sublayer_norm:
                LOG.info(Fore.RED + message + Fore.RESET)
                debug_info["match_expected"] = True
            else:
                LOG.info(
                    Fore.YELLOW
                    + message
                    + "However, this is expected since --zero-init"
                    + " and --no-sublayer-norm were not passed."
                    + Fore.RESET
                )
                debug_info["match_expected"] = False

    return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """Entry point for the convert-rala CLI."""
    print_axolotl_text_art()
    cfg = load_cfg(config, **kwargs)

    # Presumably the conversion must load the model with stock attention,
    # so a pre-set rala_attention flag is switched off — TODO confirm.
    if cfg.rala_attention:
        cfg.rala_attention = False

    arg_parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
    cli_args, _ = arg_parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    convert_rala(cfg, cli_args, config)


if __name__ == "__main__":
    load_dotenv()
    fire.Fire(do_cli)

View File

@@ -12,7 +12,12 @@ from axolotl.cli.utils import (
build_command,
fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.common.cli import (
ConvertDiffTransformerCliArgs,
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
)
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -77,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -240,6 +248,32 @@ def merge_lora(
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_diff_transformer import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_rala(config: str, **kwargs):
"""Convert model attention layers to RALA attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_rala import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")

View File

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator

View File

@@ -4,7 +4,7 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
@@ -18,7 +18,7 @@ LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
dataclass with arguments for preprocessing only
"""
debug: bool = field(default=False)
@@ -31,7 +31,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
dataclass with various non-training arguments
"""
debug: bool = field(default=False)
@@ -46,7 +46,7 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
dataclass with various evaluation arguments
"""
debug: bool = field(default=False)
@@ -54,10 +54,22 @@ class EvaluateCliArgs:
debug_num_examples: int = field(default=0)
@dataclass
class ConvertDiffTransformerCliArgs:
"""
dataclass with arguments for convert-diff-transformer CLI
"""
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""
@@ -481,7 +481,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"to_weight_decay": {}, # LayerNorm except bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}

View File

@@ -9,12 +9,11 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_processor
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -62,8 +61,9 @@ def evaluate_dataset(
return metrics
# pylint: disable=duplicate-code
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
@@ -79,16 +79,11 @@ def evaluate(
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load model
LOG.debug("loading model for evaluation...")
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Load processor for multimodal models if needed
processor = None
@@ -100,12 +95,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer
trainer = setup_trainer(
cfg,

View File

@@ -75,6 +75,21 @@ class BasePlugin:
None
"""
def set_attn_config(
self, cfg, model_kwargs, model_config
): # pylint: disable=unused-argument
"""
Sets attention configuration for the model.
Parameters:
cfg (dict): The configuration for the plugin.
model_kwargs (dict): The model kwargs for the plugin.
model_config (object): The model configuration.
Returns:
None
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -304,6 +319,18 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def set_attn_config(self, cfg, model_kwargs, model_config):
"""
modifies the attention configuration of the model kwargs for loading
Parameters:
cfg (dict): The configuration for the plugins.
model_kwargs (dict): The model's kwargs for construction the model
model_config (dict): The model's configuration.
"""
for plugin in self.plugins.values():
plugin.set_attn_config(cfg, model_kwargs, model_config)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -0,0 +1,10 @@
# Differential Transformer
### Usage
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
```

View File

@@ -0,0 +1,25 @@
"""Definition of differential transformer plugin."""
import logging
from axolotl.integrations.base import BasePlugin
LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
    """
    Plugin for differential transformer integration with Axolotl.
    """

    def get_input_args(self):
        """Return the dotted path of this plugin's input-args model."""
        return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"

    def pre_model_load(self, cfg):
        """Apply differential attention patch before model loading if enabled."""
        if not cfg.diff_attention:
            return
        # Imported lazily so the patch module is only loaded when needed.
        from axolotl.monkeypatch.attention.differential import (
            patch_llama_attention_classes,
        )

        patch_llama_attention_classes()

View File

@@ -0,0 +1,14 @@
"""Module for handling differential transfomer input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
    """Input args for differential transformer."""

    # Tri-state flag: None means "unset" so the plugin only activates when
    # diff_attention is explicitly enabled in the user's config.
    diff_attention: Optional[bool] = None

View File

@@ -0,0 +1,130 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}
def copy_attention_weights(
    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
    new_attn: Union[
        LlamaDifferentialAttention,
        LlamaDifferentialSdpaAttention,
        LlamaDifferentialFlashAttention2,
    ],
    zero_init: bool = False,
) -> None:
    """
    Copy weights from old attention layer to new differential attention layer.

    Q1/K1 always receive the original Q/K weights. When ``zero_init`` is
    True, Q2/K2 and the lambda parameters are zeroed out for exact
    equivalence to the original attention mechanism; otherwise Q2/K2 are
    randomly re-initialized (normal, std 0.1) and the lambda parameters are
    left as constructed.  (Previous docstring wrongly claimed Q2/K2 were
    always zeroed.)
    """
    # For Q projection: first hidden_size rows are Q1, the rest are Q2.
    new_q = torch.empty_like(new_attn.q_proj.weight.data)
    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
    if zero_init:
        new_q[new_attn.hidden_size :] = 0
    else:
        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
    new_attn.q_proj.weight.data.copy_(new_q)

    # For K projection: the original KV width gives the K1/K2 split point.
    old_kv_size = old_attn.k_proj.weight.data.size(0)
    new_k = torch.empty_like(new_attn.k_proj.weight.data)
    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
    if zero_init:
        new_k[old_kv_size:] = 0
    else:
        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
    new_attn.k_proj.weight.data.copy_(new_k)

    # For V projection (single V)
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
    # Output projection remains the same
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    # Zero out lambda parameters for exact equivalence
    if zero_init:
        nn.init.zeros_(new_attn.lambda_q1)
        nn.init.zeros_(new_attn.lambda_k1)
        nn.init.zeros_(new_attn.lambda_q2)
        nn.init.zeros_(new_attn.lambda_k2)
        nn.init.zeros_(new_attn.lambda_init)

    logger.debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )
def convert_to_diff_attn(
    model: PreTrainedModel,
    zero_init: bool = False,
    sublayer_norm: bool = True,
    split_heads: bool = True,
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    converted_count = 0

    # Record conversion options on the config so the new layers can read them.
    model.config.sublayer_norm = sublayer_norm
    model.config.split_heads = split_heads

    def convert_module(module):
        nonlocal converted_count
        # Walk children; swap any recognized attention layer for its
        # differential counterpart, recursing into everything else.
        for name, child in module.named_children():
            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
                new_cls = ATTENTION_MAPPING[type(child)]
                logger.info(
                    f"Converting attention layer {converted_count}: "
                    f"{type(child).__name__} to {new_cls.__name__}"
                )

                replacement = new_cls(
                    config=module.config if hasattr(module, "config") else model.config,
                    layer_idx=converted_count,
                )
                replacement.to(child.q_proj.weight.device)
                if not split_heads:
                    copy_attention_weights(child, replacement, zero_init=zero_init)

                setattr(module, name, replacement)
                converted_count += 1
            elif len(list(child.children())) > 0:
                convert_module(child)

    convert_module(model)
    logger.info(f"Converted {converted_count} attention layers to differential attention")

    return model

View File

@@ -0,0 +1,375 @@
"""Re-implemention of differential attention."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any, Optional, Tuple
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
    bsz, num_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head dimension — equivalent to repeat_interleave on dim 1.
    expanded = x[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)
def lambda_init_fn(depth):
    """Depth-dependent initial lambda value; grows from 0.2 toward 0.8."""
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
class DifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""
def __init__(self, config: Any, layer_idx: int):
    # NOTE: call order matters — _init_projections and
    # _init_differential_params read fields set by _init_config.
    super().__init__()
    self._init_config(config, layer_idx)
    self._init_projections()
    self._init_differential_params()
    self._init_normalization(config)
def _init_config(self, config: Any, layer_idx: int):
    """Initialize configuration parameters."""
    self.attention_dropout = config.attention_dropout
    self.hidden_size = config.hidden_size
    self.base_num_heads = config.num_attention_heads
    self.base_num_kv_heads = config.num_key_value_heads
    self.layer_idx = layer_idx
    self.max_position_embeddings = config.max_position_embeddings
    self.rope_theta = config.rope_theta
    self.is_causal = True
    self.split_heads = config.split_heads
    if config.split_heads:
        # Split heads mode - single projections; each half of the heads
        # serves one attention component, so head_dim is halved.
        self.head_dim = config.hidden_size // config.num_attention_heads // 2
        # NOTE: This rounds down `base_num_heads / 2` as opposed to the original
        # implementation, which asserts `self.base_num_heads` is even.
        self.heads_per_component = self.base_num_heads // 2
        self.value_head_dim = 2 * self.head_dim
    else:
        # Double projection mode - full head_dim, doubled Q/K projections.
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.heads_per_component = self.base_num_heads
        self.value_head_dim = self.head_dim
def _init_projections(self):
    """Initialize Q, K, V projections."""
    # Width of the key/value projection for the base (GQA) head layout.
    kv_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
    # Split-heads mode keeps the original projection widths; double-projection
    # mode doubles Q/K so each half parameterizes one attention component.
    factor = 1 if self.split_heads else 2
    self.q_proj = nn.Linear(self.hidden_size, self.hidden_size * factor, bias=False)
    self.k_proj = nn.Linear(self.hidden_size, kv_dim * factor, bias=False)
    self.v_proj = nn.Linear(self.hidden_size, kv_dim, bias=False)
    self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def _init_differential_params(self):
"""Initialize differential attention parameters."""
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(
self.max_position_embeddings, self.head_dim, self.rope_theta
)
def _init_normalization(self, config):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
bsz, q_len, _ = hidden_states.size()
# Project and split
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
# Reshape
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self, q1, q2, k1, k2, position_ids, position_embeddings
):
"""Apply rotary embeddings to queries and keys."""
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q1.size(-2), device=q1.device)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
"""Handle caching for autoregressive generation."""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
return k1, k2, v
def _compute_lambda(self, q1):
"""Compute lambda values for differential attention."""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
return lambda_1 - lambda_2 + self.lambda_init
def _process_attention_output(self, attn, bsz, q_len):
"""Process and project attention output."""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(DifferentialAttentionBase):
    """Standard (eager) implementation of differential attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Eager (matmul + softmax) differential attention forward pass.

        Returns ``(attn_output, attn_weights, past_key_value)`` where
        ``attn_weights`` is the differential map ``attn1 - lambda * attn2``
        when ``output_attentions`` is set, else ``None``.
        """
        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
        # Standard scaled dot-product attention scores for each component
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            # Additive mask, truncated to the current key length
            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
            attn1 = attn1 + causal_mask
            attn2 = attn2 + causal_mask
        # Softmax in float32 for numerical stability, then cast back
        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
        dropout_p = self.attention_dropout if self.training else 0.0
        attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
        attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
        # Differential combination: subtract the lambda-scaled second map
        lambda_full = self._compute_lambda(q1)
        attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
        attn = self._process_attention_output(attn, bsz, q_len)
        if output_attentions:
            return attn, attn1 - lambda_full * attn2, past_key_value
        return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
    """SDPA-based implementation of differential attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Differential attention via ``F.scaled_dot_product_attention``.

        Returns ``(attn_output, None, past_key_value)``.
        """
        # SDPA cannot return attention weights, so fall back to the eager
        # implementation when they are requested.
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
        # SDPA-specific attention computation; truncate mask to key length
        causal_mask = (
            None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
        )
        # Only let SDPA build its own causal mask when no explicit mask is
        # given and there is more than one query token.
        is_causal = attention_mask is None and q_len > 1
        dropout_p = self.attention_dropout if self.training else 0.0
        if q1.device.type == "cuda" and causal_mask is not None:
            # NOTE(review): presumably works around the known SDPA issue with
            # non-contiguous inputs + custom masks on CUDA — confirm still needed.
            q1, q2 = q1.contiguous(), q2.contiguous()
            k1, k2 = k1.contiguous(), k2.contiguous()
            v = v.contiguous()
        attn1 = F.scaled_dot_product_attention(
            q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
        )
        attn2 = F.scaled_dot_product_attention(
            q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
        )
        # Differential combination of the two attention outputs
        lambda_full = self._compute_lambda(q1)
        attn = attn1 - lambda_full * attn2
        attn = self._process_attention_output(attn, bsz, q_len)
        return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
    """Flash Attention 2-based implementation of differential attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """Differential attention via ``flash_attn_func`` (causal only).

        Returns ``(attn_output, None, past_key_value)``.
        """
        # flash-attn cannot return attention weights; fall back to eager.
        if output_attentions:
            return LlamaDifferentialAttention.forward(
                self,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        bsz, q_len, _ = hidden_states.size()
        q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
        q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
            q1, q2, k1, k2, position_ids, position_embeddings
        )
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
        # flash_attn_func expects [bsz, seq, heads, head_dim]
        q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
        k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
        v = v.transpose(1, 2)
        dropout_p = self.attention_dropout if self.training else 0.0
        if self.split_heads:
            # In split-heads mode v carries 2 * head_dim; flash-attn requires
            # matching Q/K/V head dims, so split V in half and run four calls,
            # concatenating each component's halves back together.
            v1, v2 = v.chunk(2, dim=-1)
            attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
            attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
            attn1 = torch.cat([attn11, attn12], dim=-1)
            attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
            attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
            attn2 = torch.cat([attn21, attn22], dim=-1)
        else:
            attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
            attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)
        # Back to [bsz, heads, seq, head_dim] for the differential combination
        attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)
        lambda_full = self._compute_lambda(q1)
        attn = attn1 - lambda_full * attn2
        attn = self._process_attention_output(attn, bsz, q_len)
        return attn, None, past_key_value

View File

@@ -0,0 +1,34 @@
"""Definition of RALA plugin."""
import logging
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
LOG = logging.getLogger(__name__)
class RalaPlugin(BasePlugin):
    """
    Plugin for Rala integration with Axolotl.
    """

    def get_input_args(self):
        """Return the dotted path to the pydantic args model for this plugin."""
        return "axolotl.integrations.rala.args.RalaArgs"

    def pre_model_load(self, cfg):
        """Apply differential attention patch before model loading if enabled."""
        if not cfg.rala_attention:
            return
        # Register the RALA implementation under the "rala" key, then patch
        # the Llama attention class registry so it is actually picked up.
        LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention
        from axolotl.monkeypatch.attention.differential import (
            patch_llama_attention_classes,
        )

        patch_llama_attention_classes()

    def set_attn_config(self, cfg, model_kwargs, model_config):
        """Force the RALA attention implementation when enabled in the config."""
        if cfg.rala_attention:
            model_kwargs["attn_implementation"] = "rala"

View File

@@ -0,0 +1,14 @@
"""Module for handling RALA input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class RalaArgs(BaseModel):
    """Pydantic model declaring the RALA-specific axolotl config keys."""

    # When truthy, the RalaPlugin patches Llama attention to use RALA.
    rala_attention: Optional[bool] = None

View File

@@ -0,0 +1,12 @@
"""
Rala config class
"""
from transformers import LlamaConfig
class LlamaRalaConfig(LlamaConfig):
    """
    Configuration for LlamaRala model
    """

    # Interval at which standard softmax attention layers are retained; see
    # LlamaRalaDecoderLayer.is_layer_idx_softmax for the exact layer selection.
    # (The previous comment said "every 8th layer", contradicting the value.)
    softmax_every: int = 6

View File

@@ -0,0 +1,597 @@
# Copyright 2024-2025 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Apache License 2.0 (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
Custom modeling code for RALA Llama
"""
from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache, GenerationMixin, LlamaModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
KwargsForCausalLM,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
from .configuration_rala import LlamaRalaConfig
def kappa(x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
"""
The paper uses κ(x) = ELU(x) + 1.
x is assumed to be [batch, n_heads, seq_len, head_dim].
"""
return F.elu(x) + 1
class LlamaRALAAttention(nn.Module):
    """
    LlamaAttention replaced with Rank-Augmented Linear Attention (RALA).

    Adapted from the standard LlamaAttention for demonstration.
    **Not** a fully drop-in replacement if you need caching/TP.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Same Q, K, V, output projections as stock Llama attention
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias
        )
        # We will preserve rope usage
        self._init_rope()
        # A simple φ-projection for RALA: the paper allows φ(x) to be a linear
        # transform or identity; a bias-free linear layer is used here.
        self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _init_rope(self):
        """Build the rotary embedding module, honoring rope_scaling config."""
        # Standard Llama rope logic
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        RALA forward pass.

        This version omits incremental decoding with `past_key_value` for simplicity
        (linear attention caching is non-trivial).

        NOTE(review): the `position_embeddings` argument is accepted but unused —
        cos/sin are always recomputed from `position_ids`; confirm this is
        intentional for the pinned transformers version.
        """
        bsz, q_len, _ = hidden_states.size()
        # Standard Q, K, V
        query_states = self.q_proj(hidden_states)  # [b, seq, n_heads*dim]
        key_states = self.k_proj(hidden_states)  # [b, seq, n_kv_heads*dim]
        value_states = self.v_proj(hidden_states)  # [b, seq, n_kv_heads*dim]
        # Reshape to [b, n_heads, seq_len, head_dim]
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # Apply RoPE (rotary embeddings) just as in standard Llama
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        # If we have a past_key_value (Cache object), let it update / append
        if past_key_value is not None:
            # Normal Llama pattern: .update() returns the concatenated K/V and
            # records them in the cache for this layer.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Expand KV heads for grouped-query setups
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # RALA step 1: positive feature map κ(.) on Q, K
        Q_kappa = kappa(query_states)  # pylint: disable=invalid-name
        K_kappa = kappa(key_states)  # pylint: disable=invalid-name
        # Step 2: global query Q_g = mean of κ(Q) over the sequence axis
        # NOTE(review): with a cache, keys may be longer than q_len, yet alpha
        # below is rescaled by q_len only — confirm this matches the intended
        # normalization when decoding incrementally.
        seq_len_float = float(q_len)  # used to rescale alpha after softmax
        Q_g = Q_kappa.mean(  # pylint: disable=invalid-name
            dim=2
        )  # [b, n_heads, head_dim]
        # Step 3: per-token weights alpha_j = N * softmax(Q_g · κ(K_j)).
        # Dot product over head_dim; written as mul+sum (not einsum) so that
        # torch.compile handles it.
        # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa)  # equivalent
        logits = (Q_g.unsqueeze(2) * K_kappa).sum(
            dim=-1
        )  # [b, n_heads, seq_len]
        # Step 4: fold the attention mask into the logits.  RALA has a single
        # softmax over key positions j (no per-query row), so a 4D causal mask
        # [b, 1, q_len, kv_len] cannot be applied exactly; the min over the
        # query axis is used as a conservative approximation (a key masked for
        # any query is masked for all).  This is best-effort, not strictly causal.
        if attention_mask is not None:
            if attention_mask.dim() == 4:
                # attention_mask: [b, 1, q_len, kv_len]
                mask_2d = attention_mask[:, 0, :, :q_len]  # [b, q_len, seq_len]
                # Collapse the query axis: take the most restrictive (min)
                # mask value for each key position.
                mask_1d = torch.min(mask_2d, dim=1)[
                    0
                ]  # [b, seq_len], worst mask across query positions
                # broadcast for n_heads
                mask_1d = mask_1d.unsqueeze(1).expand(
                    -1, self.num_heads, -1
                )  # [b, n_heads, seq_len]
                logits = logits + mask_1d
            else:
                # Possibly [b, seq_len]; broadcast to [b, n_heads, seq_len].
                mask_1d = attention_mask  # [b, seq_len]
                mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1)
                logits = logits + mask_1d
        alpha = F.softmax(logits, dim=-1)  # [b, n_heads, seq_len]
        # multiply by seq_len per the formula alpha_j = N * softmax(...)
        alpha = alpha * seq_len_float
        # Step 5: per-head d×d matrix outer_sum = Σ_j alpha_j * (κ(K_j) ⊗ V_j).
        # einsum: alpha [b,n,s] × K_kappa [b,n,s,d] × V [b,n,s,f] -> [b,n,d,f].
        value_states_ = value_states  # [b, n_heads, seq_len, head_dim]
        outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_)
        # Step 6: per token i, Y_i = φ(X_i) ∘ (κ(Q_i) × outer_sum).
        # φ(X) is computed once for the whole sequence and reshaped per-head.
        X_phi = self.phi(  # pylint: disable=invalid-name
            hidden_states
        )  # [b, seq_len, d_model]
        X_phi = X_phi.view(  # pylint: disable=invalid-name
            bsz, q_len, self.num_heads, self.head_dim
        )  # [b, s, n, d]
        X_phi = X_phi.transpose(1, 2)  # [b, n, s, d]  # pylint: disable=invalid-name
        # κ(Q) [b,n,s,d] × outer_sum [b,n,d,f] -> [b,n,s,f]
        result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum)  # [b,n,s,d]
        # Elementwise gate by φ(X)
        context_layer = X_phi * result_attn  # [b,n,s,d]
        # Reorder to [b, s, n, d] and merge heads -> [b, s, n*d]
        context_layer = context_layer.transpose(1, 2).contiguous()  # [b, s, n, d]
        context_layer = context_layer.view(bsz, q_len, self.hidden_size)
        # One last linear projection
        attn_output = self.o_proj(context_layer)
        if output_attentions:
            # alpha => [b, n_heads, (past_len + q_len)]; note this is a
            # per-key weight vector, not a full [q_len, kv_len] map.
            attn_weights = alpha
        else:
            attn_weights = None
        # Return 3-tuple: (attn_output, attn_weights, past_key_value)
        return attn_output, attn_weights, past_key_value
class LlamaRalaDecoderLayer(nn.Module):
    """
    LlamaDecoderLayer with RALA support
    """

    def __init__(self, config: LlamaRalaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    @classmethod
    def is_layer_idx_softmax(
        cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
    ) -> bool:
        """Return True if layer ``layer_idx`` keeps standard softmax attention.

        The first and last layers always keep softmax.  Between them, softmax
        layers recur every ``softmax_every`` layers starting from an offset
        chosen so the retained group is roughly centered in the stack.
        """
        # Interior layers (the first and last are handled unconditionally below)
        inner_layers = num_hidden_layers - 2
        if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers:
            # The stride pattern fits exactly starting right after layer 0.
            softmax_start_idx = 1
        elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers:
            # Pattern overshoots: shrink by one stride and center the remainder.
            layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1)
            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
        elif 1 + softmax_every * (inner_layers // softmax_every) < inner_layers:
            # Pattern undershoots: center the leftover slack around the group.
            layer_group_size = 1 + softmax_every * (inner_layers // softmax_every)
            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
        softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every))
        softmax_layers.add(0)
        softmax_layers.add(num_hidden_layers - 1)
        return layer_idx in softmax_layers

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        # Pre-norm residual block around self-attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states
        # Fully Connected (pre-norm residual block around the MLP)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)  # type: ignore
        if use_cache:
            outputs += (present_key_value,)  # type: ignore
        return outputs  # type: ignore
class LlamaRalaModel(LlamaModel):
    """
    LlamaModel with RALA support
    """

    config_class = LlamaRalaConfig

    def __init__(self, config: LlamaRalaConfig):
        # Deliberately call LlamaPreTrainedModel.__init__ (not LlamaModel's)
        # so the decoder stack can be built from LlamaRalaDecoderLayer instead
        # of the stock LlamaDecoderLayer.
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                LlamaRalaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    """
    LlamaForCausalLM with RALA support
    """

    config_class = LlamaRalaConfig
    _no_split_modules = ["LlamaRalaDecoderLayer"]
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaRalaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],  # type: ignore
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # When num_logits_to_keep == 0 the slice `-0:` equals `0:`, i.e. the
        # full sequence — this is the documented special case.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        loss = None
        if labels is not None:
            # loss_function is supplied by the transformers base class; it
            # handles the causal shift internally.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

View File

@@ -0,0 +1,104 @@
"""
conversion for llama models to use RALA attention
"""
import logging
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
logger = logging.getLogger(__name__)
# Maps stock HF attention classes to their RALA replacements; used by
# convert_to_rala to decide which child modules to swap out.
ATTENTION_MAPPING = {
    LlamaAttention: LlamaRALAAttention,
}
def copy_attention_weights(
    old_attn,
    new_attn,
    zero_init: bool = False,
) -> None:
    """
    Copy weights from old attention layer to new RALA layer.

    The q/k/v/o projection weights are copied verbatim; the RALA-specific
    ``phi`` projection has no counterpart in the source layer and is
    re-initialized instead.

    Args:
        old_attn: source attention module exposing ``q_proj``/``k_proj``/
            ``v_proj``/``o_proj`` linear layers.
        new_attn: target RALA module with matching projections plus ``phi``.
        zero_init: if True, zero ``phi.weight`` so the converted layer is
            exactly equivalent to the original at conversion time;
            otherwise draw a normal init for ``phi``.
    """
    new_attn.q_proj.weight.data.copy_(old_attn.q_proj.weight.data)
    new_attn.k_proj.weight.data.copy_(old_attn.k_proj.weight.data)
    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)

    if zero_init:
        # Zero phi for exact output equivalence with the source model.
        nn.init.zeros_(new_attn.phi.weight)
    else:
        nn.init.normal_(new_attn.phi.weight)
        # BUG FIX: the previous `if new_attn.phi.bias:` evaluated tensor
        # truthiness, which raises "Boolean value of Tensor with more than
        # one element is ambiguous" for any bias of size > 1. Test for the
        # parameter's presence, not its value.
        if new_attn.phi.bias is not None:
            nn.init.normal_(new_attn.phi.bias)

    logging.getLogger(__name__).debug(
        "Copied positive attention weights from %s to %s",
        type(old_attn).__name__,
        type(new_attn).__name__,
    )
def convert_to_rala(
    model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
) -> PreTrainedModel:
    """Convert a pre-trained Llama model's attention layers to RALA attention.

    Walks the module tree, replaces each attention layer listed in
    ``ATTENTION_MAPPING`` with its RALA counterpart (copying q/k/v/o weights),
    and rewrites the config (``model_type``, ``architectures``, ``auto_map``)
    so the saved checkpoint loads through the custom modeling code.

    Args:
        model: the model to convert in place.
        zero_init: forwarded to ``copy_attention_weights``; zeros the phi
            projection for exact equivalence at conversion time.
        softmax_every_n: layers for which
            ``LlamaRalaDecoderLayer.is_layer_idx_softmax`` returns True keep
            their original softmax attention and are skipped.

    Returns:
        The same ``model`` instance, mutated in place.
    """
    layer_idx = 0

    def convert_module(module, softmax_every, num_hidden_layers):
        nonlocal layer_idx
        # Iterate through module children, convert any attn layers to RALA
        for name, child in module.named_children():
            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
                decoder_layer_idx = child.layer_idx
                # Keep every softmax_every-th layer as standard softmax attention.
                if LlamaRalaDecoderLayer.is_layer_idx_softmax(
                    num_hidden_layers, decoder_layer_idx, softmax_every
                ):
                    continue
                # Choose the replacement attention class for this layer type.
                # pylint: disable=duplicate-code
                attention_class = ATTENTION_MAPPING[type(child)]
                layer_type = type(child).__name__
                logger.info(
                    f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
                )
                # NOTE(review): `layer_idx` is a running count of *converted*
                # layers and diverges from `child.layer_idx` once a softmax
                # layer has been skipped — confirm the new layer should not
                # receive `decoder_layer_idx` instead.
                new_attention = attention_class(
                    config=module.config if hasattr(module, "config") else model.config,
                    layer_idx=layer_idx,
                )
                # Copy weights onto the same device as the old layer.
                new_attention.to(child.q_proj.weight.device)
                copy_attention_weights(child, new_attention, zero_init=zero_init)
                # Replace the layer in its parent module.
                setattr(module, name, new_attention)
                layer_idx += 1
            elif len(list(child.children())) > 0:
                convert_module(child, softmax_every, num_hidden_layers)

    model.config.softmax_every = softmax_every_n
    convert_module(model, softmax_every_n, model.config.num_hidden_layers)
    logger.info(f"Converted {layer_idx} attention layers to RALA attention")

    # Point the config at the custom modeling code so save/load round-trips.
    model.config.architectures = [
        "LlamaRalaForCausalLM",
    ]
    model.config.model_type = "llama_rala"
    model.config.auto_map = {
        "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
        "AutoModel": "llama.modeling_rala.LlamaRalaModel",
        "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
    }
    return model

View File

@@ -0,0 +1,280 @@
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
def kappa(x: torch.Tensor) -> torch.Tensor:
    """Positive feature map used by RALA: κ(x) = ELU(x) + 1.

    Shifting ELU by one keeps every activation strictly positive, which the
    linear-attention normalization requires. ``x`` is typically shaped
    [batch, n_heads, seq_len, head_dim].
    """
    return F.elu(x).add(1)
class LlamaRALAAttention(nn.Module):
    """
    LlamaAttention replaced with Rank-Augmented Linear Attention (RALA).

    Adapted from the standard LlamaAttention for demonstration.
    **Not** a fully drop-in replacement if you need caching/TP:
    ``past_key_value``, ``cache_position``, ``use_cache`` and
    ``position_embeddings`` are accepted for signature compatibility but
    never read by ``forward``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        """Build projections and RoPE from a Llama-style ``config``.

        Args:
            config: Llama config providing attention_dropout, hidden_size,
                num_attention_heads, num_key_value_heads,
                max_position_embeddings, rope_theta, attention_bias and
                (optionally) rope_scaling.
            layer_idx: index of this layer within the decoder stack.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Same Q, K, V, output projections as stock LlamaAttention, so
        # weights can be copied over 1:1 during conversion.
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.hidden_size, self.hidden_size, bias=config.attention_bias
        )

        # We will preserve rope usage
        self._init_rope()

        # A simple φ-projection for RALA: the paper allows φ(x) to be a
        # linear transform or identity; here it is a full linear layer.
        self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=True)

    def _init_rope(self):
        # Standard Llama rope logic.
        # NOTE(review): this uses the older transformers RoPE constructor
        # signature (head_dim positional, `base`/`scaling_factor` kwargs) —
        # confirm against the pinned transformers version.
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,  # pylint: disable=unused-argument
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,  # pylint: disable=unused-argument
    ):
        """
        RALA forward pass.

        Returns ``(attn_output, attn_weights, None)``. ``attn_weights`` is
        the global per-key weighting ``alpha`` of shape
        [b, n_heads, seq_len] when ``output_attentions`` is True (NOT a
        q_len x kv_len attention map), else None.

        This version omits incremental decoding with `past_key_value` for
        simplicity (linear attention caching is non-trivial).
        """
        bsz, q_len, _ = hidden_states.size()

        # Standard Q, K, V projections.
        query_states = self.q_proj(hidden_states)  # [b, seq, n_heads*dim]
        key_states = self.k_proj(hidden_states)  # [b, seq, n_kv_heads*dim]
        value_states = self.v_proj(hidden_states)  # [b, seq, n_kv_heads*dim]

        # Reshape to [b, n_heads, seq_len, head_dim]
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Apply RoPE (rotary embeddings) just as in standard Llama.
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Expand grouped KV heads to the full head count (GQA support).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # --- RALA proper ---
        # 1) Positive feature map κ(.) on Q,K: [b, n_heads, seq_len, head_dim]
        Q_kappa = kappa(query_states)
        K_kappa = kappa(key_states)

        # 2) Global query Q_g = (1/N) Σ_i κ(Q_i) over the sequence axis
        #    => [b, n_heads, head_dim]
        seq_len_float = float(q_len)  # for scaling
        Q_g = Q_kappa.mean(dim=2)  # [b, n_heads, head_dim]

        # 3) Per-key logits: dot product of Q_g with each κ(K_j) over
        #    head_dim => [b, n_heads, seq_len]. Written as a broadcasted
        #    multiply-sum rather than einsum so torch.compile handles it:
        # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa)  # [b, n_heads, seq_len]
        logits = (Q_g.unsqueeze(2) * K_kappa).sum(
            dim=-1
        )  # -> [b, n_heads, seq_len] # identical to above but torch.compile should work

        # 4) Fold the attention mask into the logits. RALA has only a single
        #    softmax over the key axis, so a [b, 1, q_len, kv_len] causal mask
        #    cannot be applied exactly.
        #    NOTE(review): taking the min over the query axis masks key j only
        #    if it is masked for EVERY query position — a deliberate
        #    approximation, not strict causality. Confirm acceptable.
        if attention_mask is not None:
            if attention_mask.dim() == 4:
                # attention_mask: [b, 1, q_len, kv_len] with 0 or -inf entries.
                mask_2d = attention_mask[:, 0, :, :q_len]  # [b, q_len, seq_len]
                # Collapse the query axis: keep the "worst" (most negative)
                # mask value each key sees across all query positions.
                mask_1d = torch.min(mask_2d, dim=1)[
                    0
                ]  # [b, seq_len], picking the worst mask across query positions
                # Broadcast over heads.
                mask_1d = mask_1d.unsqueeze(1).expand(
                    -1, self.num_heads, -1
                )  # [b, n_heads, seq_len]
                logits = logits + mask_1d
            else:
                # Assume a 2-D [b, seq_len] additive mask; broadcast to heads.
                mask_1d = attention_mask  # [b, seq_len]
                mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1)
                logits = logits + mask_1d

        alpha = F.softmax(logits, dim=-1)  # [b, n_heads, seq_len]
        # Multiply by N per the paper (alpha_j = N * softmax(...)_j).
        alpha = alpha * seq_len_float

        # 5) Rank-augmented summary matrix per head:
        #    outer_sum = Σ_j alpha_j * (κ(K_j) ⊗ V_j)  => [b, n_heads, d, d]
        #    'bns' = alpha, 'bnsd' = κ(K), 'bnsf' = V; contract over s.
        value_states_ = value_states  # [b, n_heads, seq_len, head_dim]
        outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_)

        # 6) Per-token output: Y_i = φ(X_i) ∘ (κ(Q_i) · outer_sum).
        # Project the ORIGINAL hidden states through φ and split into heads.
        X_phi = self.phi(hidden_states)  # [b, seq_len, d_model]
        X_phi = X_phi.view(bsz, q_len, self.num_heads, self.head_dim)  # [b, s, n, d]
        X_phi = X_phi.transpose(1, 2)  # [b, n, s, d]

        # κ(Q) [b,n,s,d] times outer_sum [b,n,d,d] => [b,n,s,d]
        result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum)  # [b,n,s,d]
        # Elementwise gate by φ(X).
        context_layer = X_phi * result_attn  # [b,n,s,d]

        # Merge heads back: [b, n, s, d] -> [b, s, n*d]
        context_layer = context_layer.transpose(1, 2).contiguous()  # [b, s, n, d]
        context_layer = context_layer.view(bsz, q_len, self.hidden_size)

        # Final output projection.
        attn_output = self.o_proj(context_layer)

        # alpha is only a global per-key weighting, not a standard
        # (q_len x kv_len) attention map; returned for inspection only.
        if output_attentions:
            attn_weights = alpha
        else:
            attn_weights = None

        # No KV cache is produced (third element always None).
        return attn_output, attn_weights, None

View File

@@ -0,0 +1,49 @@
"""Patches related to differential transformers implementation."""
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.diff_transformer.diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
def patch_llama_attention_classes():
    """Patch transformers to support differential attention.

    Two side effects:
      1. Registers the differential attention classes in
         ``LLAMA_ATTENTION_CLASSES`` under new implementation keys.
      2. Replaces ``PreTrainedModel._autoset_attn_implementation`` so that
         the differential / rala keys pass validation.
    """
    # Add our attention classes to the registry so that
    # config._attn_implementation can select them by name.
    LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
    LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
    LLAMA_ATTENTION_CLASSES[
        "differential_flash_attention_2"
    ] = LlamaDifferentialFlashAttention2

    @classmethod
    def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
        # Replacement for the upstream autoset: marks the config as resolved
        # and validates the implementation name against our extended list.
        # NOTE(review): this drops the dtype/device/availability checks the
        # upstream implementation performs — confirm that is intended.
        config._attn_implementation_autoset = True  # pylint: disable=protected-access
        attn_implementation = getattr(config, "_attn_implementation", None)
        valid_impls = [
            None,
            "eager",
            "sdpa",
            "flash_attention_2",
            "differential_eager",
            "differential_sdpa",
            "differential_flash_attention_2",
            "rala",
        ]
        if attn_implementation not in valid_impls:
            message = (
                f"Specified `attn_implementation={attn_implementation}` is not supported. "
                f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
            )
            raise ValueError(message + ".")
        return config

    # Apply patch
    PreTrainedModel._autoset_attn_implementation = (  # pylint: disable=protected-access
        new_autoset
    )

View File

@@ -48,6 +48,7 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -375,8 +376,6 @@ class ModelLoader:
def apply_patches(self) -> None:
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
@@ -713,24 +712,53 @@ class ModelLoader:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if self.cfg.differentiaion:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
"differential_eager"
)
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
plugin_manager = PluginManager.get_instance()
plugin_manager.set_attn_config(self.cfg, self.model_kwargs, self.model_config)
def build_model(self, qlora_fsdp) -> bool:
def _configure_zero3_memory_efficient_loading():
"""
@@ -816,6 +844,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,

151
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,151 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
    """Tracks the order of keys and section breaks in YAML files.

    Parses a reference YAML file line-by-line (without a YAML library) to
    record (a) the nesting order of mapping keys and (b) which top-level
    keys were preceded by a blank line or comment, so a re-dump can
    reproduce the original layout.

    NOTE(review): the parser assumes 2-space indentation steps and only
    tracks `key:` lines; list items and multi-line scalars are not modeled.
    """

    def __init__(self, yaml_path: str):
        # Path of the reference file; parsed eagerly at construction.
        self.yaml_path = yaml_path
        self.structure, self.needs_break = self._parse_yaml_structure()

    def _get_indentation_level(self, line: str) -> int:
        """Get the indentation level (number of leading spaces) of a line."""
        return len(line) - len(line.lstrip())

    def _parse_yaml_structure(
        self,
    ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
        """Parse the YAML file to extract structure and identify section breaks.

        Returns:
            A tuple of (nested OrderedDict mirroring the key hierarchy,
            set of top-level keys that should be preceded by a blank line).
        """
        with open(self.yaml_path, "r", encoding="utf-8") as file:
            contents = file.readlines()

        structure: OrderedDict = OrderedDict()
        needs_break = set()  # Track which keys should have a break before them
        current_path: List[str] = []  # key path from root to the current line
        last_indentation = -1
        had_empty_line = False

        for line in contents:
            # Blank lines and comments only flag that a section break occurred.
            if not line.strip() or line.strip().startswith("#"):
                had_empty_line = True
                continue

            indentation = self._get_indentation_level(line)
            content = line.strip()

            # Skip lines that don't define keys (list items, continuations).
            if ":" not in content:
                continue

            # Extract the key name before the first colon.
            key = content.split(":")[0].strip()

            # A top-level key following an empty line/comment gets a break.
            if indentation == 0:
                if had_empty_line:
                    needs_break.add(key)
                had_empty_line = False

            # Maintain current_path based on indentation changes.
            # NOTE(review): dedent assumes exactly 2 spaces per level.
            if indentation > last_indentation:
                current_path.append(key)
            elif indentation < last_indentation:
                levels_up = (last_indentation - indentation) // 2
                current_path = current_path[:-levels_up]
                current_path[-1] = key
            else:
                if current_path:
                    current_path[-1] = key

            # Materialize the path in the nested structure.
            current_dict = structure
            for path_key in current_path[:-1]:
                if path_key not in current_dict:
                    current_dict[path_key] = OrderedDict()
                current_dict = current_dict[path_key]
            if current_path:
                if current_path[-1] not in current_dict:
                    current_dict[current_path[-1]] = OrderedDict()

            last_indentation = indentation

        return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
    """SafeDumper subclass used as a registration target for the
    order-preserving mapping representers (see ``ordered_dict_representer``
    registration in ``dump_yaml_preserved_order``)."""
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
    """Represent a mapping node while preserving the insertion order of *data*."""
    items = data.items()
    return dumper.represent_mapping("tag:yaml.org,2002:map", items)
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
    """Return *data* reordered to follow *reference_structure*.

    Keys known to the reference come first, in reference order; nested dicts
    are reordered recursively when both sides are dicts. Keys absent from
    the reference keep their original relative order at the end.
    """
    result: OrderedDict = OrderedDict()

    # Pass 1: keys in reference order.
    for ref_key in reference_structure:
        if ref_key not in data:
            continue
        value = data[ref_key]
        ref_value = reference_structure[ref_key]
        if isinstance(ref_value, dict) and isinstance(value, dict):
            result[ref_key] = reorder_dict(value, ref_value)
        else:
            result[ref_key] = value

    # Pass 2: leftover keys, original order.
    for key, value in data.items():
        if key not in result:
            result[key] = value

    return result
def dump_yaml_preserved_order(
    data: Dict, reference_yaml_path: str, output_path: str
) -> None:
    """Dump *data* as YAML, mimicking the key order and blank-line layout
    of the reference file at *reference_yaml_path*."""
    # Learn key order and section breaks from the reference file.
    tracker = YAMLOrderTracker(reference_yaml_path)
    ordered_data = reorder_dict(data, tracker.structure)

    # Make the dumper emit mappings in insertion order.
    OrderedDumper.add_representer(dict, ordered_dict_representer)
    OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)

    dumped = yaml.dump(
        ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
    )

    # Re-insert blank lines before top-level keys that had them originally.
    output_lines: List[str] = []
    for line in dumped.split("\n"):
        top_level_key = bool(line.strip()) and ":" in line and not line.startswith(" ")
        if top_level_key:
            key = line.split(":")[0].strip()
            if key in tracker.needs_break and output_lines and output_lines[-1] != "":
                output_lines.append("")
        output_lines.append(line)

    with open(output_path, "w", encoding="utf-8") as file:
        file.write("\n".join(output_lines))

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch

View File

@@ -0,0 +1,31 @@
"""Shared fixtures for differential transformer conversion tests."""
import pytest
from click.testing import CliRunner
@pytest.fixture()
def base_config():
    """Minimal training config shared by the conversion tests."""
    alpaca_dataset = {
        "path": "axolotl-ai-co/alpaca_100_test",
        "type": "alpaca",
    }
    return {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [alpaca_dataset],
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-4,
        "val_set_size": 0.1,
        "micro_batch_size": 1,
        "sequence_len": 2048,
        "special_tokens": {"pad_token": "<|endoftext|>"},
    }
@pytest.fixture
def cli_runner():
    """Provide a fresh Click ``CliRunner`` for invoking the axolotl CLI."""
    return CliRunner()

View File

@@ -0,0 +1,51 @@
"""End-to-end tests for differential transformer conversion and evaluation."""
# pylint: disable=duplicate-code
from pathlib import Path
import yaml
from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
    """Convert a model to diff-transformer form, then evaluate the result."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)

    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(cfg_file))

    assert debug_info["generations_match"] is True
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()

    eval_cfg = load_cfg(str(converted_dir))
    all_metrics = do_evaluate(eval_cfg, EvaluateCliArgs())

    assert list(all_metrics.keys()) == [
        "train_loss",
        "train_model_preparation_time",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "eval_loss",
        "eval_model_preparation_time",
        "eval_runtime",
        "eval_samples_per_second",
        "eval_steps_per_second",
    ]
    assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
    assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)

View File

@@ -0,0 +1,147 @@
"""End-to-end tests for differential transformer conversion."""
# pylint: disable=redefined-outer-name
# pylint: disable=duplicate-code
from pathlib import Path
from typing import Optional
from unittest.mock import patch
import pytest
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
    """convert-diff-transformer rejects missing and nonexistent configs."""
    # No CONFIG argument at all.
    outcome = cli_runner.invoke(cli, ["convert-diff-transformer"])
    assert outcome.exit_code != 0
    assert "Error: Missing argument 'CONFIG'." in outcome.output

    # CONFIG points at a file that does not exist.
    outcome = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
    assert outcome.exit_code != 0
    assert "Error: Invalid value for 'CONFIG'" in outcome.output
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
    """The CLI command dispatches to do_cli with the given config path."""
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    with patch(
        "axolotl.cli.integrations.convert_diff_transformer.do_cli"
    ) as mock_do_cli:
        outcome = cli_runner.invoke(cli, ["convert-diff-transformer", str(cfg_file)])
        assert outcome.exit_code == 0

        mock_do_cli.assert_called_once()
        assert mock_do_cli.call_args.kwargs["config"] == str(cfg_file)
def test_conversion_cli_basic(tmp_path: Path, base_config):
    """Default conversion writes artifacts and returns no debug info."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    _, debug_info = convert_diff_transformer(
        cfg, ConvertDiffTransformerCliArgs(), str(cfg_file)
    )

    assert not debug_info
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
def test_conversion_cli_debug(tmp_path: Path, base_config):
    """Debug-mode conversion (default init) should not match generations."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    _, debug_info = convert_diff_transformer(
        cfg, ConvertDiffTransformerCliArgs(debug=True), str(cfg_file)
    )

    assert not debug_info["generations_match"]
    assert not debug_info["match_expected"]
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
    """zero_init + no sublayer norm reproduces the original generations."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(cfg_file))

    assert debug_info["generations_match"] is True
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
@pytest.mark.parametrize(
    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
    tmp_path: Path, base_config, attention: Optional[str]
):
    """Exact reproduction holds for every attention implementation.

    NOTE(review): function name has a typo ("repoduce"); renaming would
    change the collected test ID, so it is left as-is.
    """
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    base_config[attention] = True
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(cfg_file))

    assert debug_info["generations_match"] is True
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
@pytest.mark.parametrize(
    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
    """split_heads changes (Q, K) dimensionality, so generations differ."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    base_config[attention] = True
    cfg_file = tmp_path / "config.yml"
    with open(cfg_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    cfg = load_cfg(str(cfg_file))
    cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(cfg_file))

    assert debug_info["generations_match"] is False
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()