Compare commits
3 Commits
feat/linea
...
modal-upgr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b56cc18d5 | ||
|
|
5c3ac90669 | ||
|
|
353ba4e80b |
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -207,7 +207,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -248,7 +248,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
|
||||
@@ -29,7 +29,7 @@ datasets:
|
||||
type: chatml.intel
|
||||
- path: argilla/ultrafeedback-binarized-preferences
|
||||
split: train
|
||||
type: chatml
|
||||
type: chatml.argilla
|
||||
```
|
||||
|
||||
#### IPO
|
||||
|
||||
@@ -13,12 +13,6 @@ class PreprocessCliArgs:
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
download: Optional[bool] = field(default=True)
|
||||
iterable: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Use IterableDataset for streaming processing of large datasets"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
||||
LinearLlamaConfig,
|
||||
)
|
||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||
LinearLlamaForCausalLM,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
"""
|
||||
Convert attention to linear attention and perform attention transfer via distillation.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
|
||||
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
|
||||
cfg.load_in_8bit = False
|
||||
cfg.load_in_4bit = False
|
||||
cfg.adapter = None
|
||||
|
||||
# load model
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# freeze model
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# convert to linear llama
|
||||
linear_llama_config = LinearLlamaConfig.from_llama(
|
||||
model.config, cfg.attention_config
|
||||
)
|
||||
model = LinearLlamaForCausalLM.from_llama(
|
||||
model, config=linear_llama_config, train_attention=True
|
||||
)
|
||||
|
||||
# set save_path, save tokenizer and model config.
|
||||
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
||||
tokenizer.save_pretrained(save_path)
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(save_path)
|
||||
|
||||
# Get datasets
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# toggle attention to be trainable
|
||||
model.toggle_attention(train=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = setup_trainer(
|
||||
cfg=cfg,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
model=(model, None, None),
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||
|
||||
# drop base_attention + remove training attn
|
||||
model.toggle_attention(train=False)
|
||||
model.remove_base_attention()
|
||||
|
||||
# NOTE: If in peft mode, consider whether to auto-merge
|
||||
|
||||
# save model
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
# NOTE: may need to consider other ways of saving due to multi-gpu etc
|
||||
model.save_pretrained(save_path, safe_serialization=safe_serialization)
|
||||
|
||||
# cleanup
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# load cfg, force linearize and add plugin to linearize
|
||||
parsed_cfg = load_cfg(
|
||||
config,
|
||||
linearize=True,
|
||||
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
do_linearize(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,17 +1,10 @@
|
||||
"""Click CLI definitions for various axolotl commands."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import logging
|
||||
import random
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
@@ -27,76 +20,6 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
|
||||
def generate_sweep_configs(base_config, sweeps_config):
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
Args:
|
||||
base_config (dict): The original configuration dictionary
|
||||
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||
- lists of values to sweep independently
|
||||
- or for paired values, a list of dicts under the '_' key
|
||||
|
||||
Returns:
|
||||
list: List of all possible configuration dictionaries
|
||||
|
||||
Example:
|
||||
sweeps_config = {
|
||||
'learning_rate': [0.1, 0.01],
|
||||
'_': [
|
||||
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Separate paired values from regular sweeps
|
||||
paired_values = sweeps_config.get("_", [])
|
||||
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||
|
||||
# Process regular sweeps
|
||||
param_names = list(regular_sweeps.keys())
|
||||
param_values = list(regular_sweeps.values())
|
||||
|
||||
# Generate combinations for regular sweeps
|
||||
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||
|
||||
# Combine regular sweeps with paired values
|
||||
all_combinations = []
|
||||
for reg_combo in regular_combinations:
|
||||
if paired_values:
|
||||
for paired_set in paired_values:
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
else:
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
|
||||
# randomize the order of trials
|
||||
random.seed(42)
|
||||
random.shuffle(all_combinations)
|
||||
|
||||
# Generate a new config for each combination
|
||||
result_configs = []
|
||||
for combination in all_combinations:
|
||||
new_config = deepcopy(base_config)
|
||||
for param_name, param_value in combination.items():
|
||||
new_config[param_name] = param_value
|
||||
result_configs.append(new_config)
|
||||
|
||||
return result_configs
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
@@ -137,21 +60,10 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--sweep",
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="YAML config for sweeping hyperparameters",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def train(
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
cloud: Optional[str] = None,
|
||||
sweep: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
@@ -159,7 +71,6 @@ def train(
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
sweep: Path to YAML config for sweeping hyperparameters.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
@@ -169,66 +80,35 @@ def train(
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
if sweep:
|
||||
# load the sweep configuration yaml file
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
|
||||
# generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
|
||||
def iter_configs():
|
||||
for perm in permutations:
|
||||
# open temp directory for temporary configurations
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with open(
|
||||
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||
) as fout:
|
||||
yaml.dump(perm, fout)
|
||||
yield str(Path(temp_dir) / "config.yaml")
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
def iter_configs():
|
||||
yield config
|
||||
|
||||
for cfg_file in iter_configs():
|
||||
# handle errors from subprocess so we can continue rest of sweeps
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -75,10 +75,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
)
|
||||
|
||||
|
||||
def do_cli(
|
||||
config: Union[Path, str] = Path("examples/"),
|
||||
**kwargs,
|
||||
) -> None:
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
|
||||
|
||||
|
||||
@@ -63,17 +63,11 @@ def load_datasets(
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = (
|
||||
hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,988 +0,0 @@
|
||||
"""
|
||||
module for customized trainers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
DPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
if isinstance(tag_names, str):
|
||||
tag_names = [tag_names]
|
||||
|
||||
if kwargs is not None:
|
||||
if "tags" not in kwargs:
|
||||
kwargs["tags"] = tag_names
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||
kwargs["tags"].extend(tag_names)
|
||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||
tag_names.append(kwargs["tags"])
|
||||
kwargs["tags"] = tag_names
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
or self.args.lr_groups is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW(
|
||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||
)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters,
|
||||
decouple=True,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
train_batch_size = (
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
if self.args.curriculum_sampling:
|
||||
sampler = SequentialSampler(self.train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
|
||||
return MultipackBatchSampler(
|
||||
sampler,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
)
|
||||
if self.args.curriculum_sampling:
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
lengths=get_dataset_lengths(self.eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_size=batch_size,
|
||||
group_size=self.args.sample_packing_group_size,
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
drop_last=True,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
)
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.train_data_collator
|
||||
)
|
||||
return dataloader
|
||||
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
concatenated_batch = {}
|
||||
|
||||
max_length = max(
|
||||
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
|
||||
)
|
||||
# Concatenate positive and negative inputs
|
||||
concatenated_batch["input_ids"] = pad_to_length(
|
||||
inputs["input_ids"], max_length, pad_token
|
||||
)
|
||||
concatenated_batch["rejected_input_ids"] = pad_to_length(
|
||||
inputs["rejected_input_ids"], max_length, pad_token
|
||||
)
|
||||
concatenated_batch["labels"] = pad_to_length(
|
||||
inputs["labels"], max_length, label_pad_token
|
||||
)
|
||||
concatenated_batch["rejected_labels"] = pad_to_length(
|
||||
inputs["rejected_labels"], max_length, label_pad_token
|
||||
)
|
||||
concatenated_batch["attention_mask"] = pad_to_length(
|
||||
inputs["attention_mask"], max_length, 0
|
||||
)
|
||||
concatenated_batch["rejected_attention_mask"] = pad_to_length(
|
||||
inputs["rejected_attention_mask"], max_length, 0
|
||||
)
|
||||
concatenated_batch["prompt_attention_mask"] = pad_to_length(
|
||||
inputs["prompt_attention_mask"], max_length, 0
|
||||
).to(device=device)
|
||||
|
||||
input_ids = torch.cat(
|
||||
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
|
||||
dim=0,
|
||||
).to(device=device)
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
concatenated_batch["attention_mask"],
|
||||
concatenated_batch["rejected_attention_mask"],
|
||||
],
|
||||
dim=0,
|
||||
).to(device=device)
|
||||
labels = torch.cat(
|
||||
[concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
|
||||
).to(device=device)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
|
||||
}
|
||||
|
||||
def orpo_compute_custom_loss(self, logits, labels):
|
||||
logits = logits.contiguous()
|
||||
loss = 0.0
|
||||
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def orpo_compute_logps(
|
||||
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
|
||||
):
|
||||
# Get the shape of chosen_attention_mask[:, :-1]
|
||||
chosen_shape = chosen_attention_mask[:, :-1].shape
|
||||
|
||||
# Calculate the padding size
|
||||
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
|
||||
|
||||
# Pad prompt_attention_mask with zeros to match the desired shape
|
||||
prompt_attention_mask_padded = torch.nn.functional.pad(
|
||||
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
|
||||
)
|
||||
|
||||
# Perform the subtraction operation
|
||||
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
|
||||
|
||||
per_token_logps = torch.gather(
|
||||
logits[:, :-1, :].log_softmax(-1),
|
||||
dim=2,
|
||||
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
||||
).squeeze(2)
|
||||
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||
|
||||
def orpo_compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||
inputs,
|
||||
label_pad_token=-100,
|
||||
pad_token=self.tokenizer.pad_token_id,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
|
||||
# Perform a single forward pass
|
||||
outputs = model(
|
||||
**{
|
||||
"input_ids": concat_inputs["input_ids"],
|
||||
"attention_mask": concat_inputs["attention_mask"],
|
||||
"labels": concat_inputs["labels"],
|
||||
},
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# Split the outputs for positive and negative examples
|
||||
outputs_pos, outputs_neg = outputs.logits.chunk(2)
|
||||
|
||||
# Calculate NLL loss
|
||||
pos_loss = self.orpo_compute_custom_loss(
|
||||
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
|
||||
)
|
||||
|
||||
# Calculate Log Probability
|
||||
pos_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
||||
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
|
||||
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
|
||||
logits=outputs_pos,
|
||||
)
|
||||
neg_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
|
||||
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
|
||||
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
|
||||
logits=outputs_neg,
|
||||
)
|
||||
|
||||
# Calculate log odds
|
||||
log_odds = (pos_prob - neg_prob) - (
|
||||
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
|
||||
)
|
||||
sig_ratio = torch.nn.functional.sigmoid(log_odds)
|
||||
ratio = torch.log(sig_ratio)
|
||||
|
||||
# Calculate the Final Loss
|
||||
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
|
||||
dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
metrics = {}
|
||||
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
|
||||
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
|
||||
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
|
||||
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
return (loss, outputs_pos) if return_outputs else loss
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
"limit_all_gathers" in self.args.fsdp_config
|
||||
and self.args.fsdp_config["limit_all_gathers"]
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
start_time (`Optional[float]`):
|
||||
The start of training.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
) -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Mamba specific trainer to handle loss calculation
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
if loraplus_lr_ratio:
|
||||
print("Using lora+")
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
Extend the base KTOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
Extend the base CPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||
"""
|
||||
Extend the base trl.PRMTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "prm"]
|
||||
@@ -1,264 +0,0 @@
|
||||
"""
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_optimizer: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate optimizer to the HF trainer"
|
||||
},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
kd_ce_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_alpha: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
)
|
||||
|
||||
kd_temperature: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_zscore_base_temp: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
Training arguments for Causal trainer
|
||||
|
||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||
so it can't be used as a mixin.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||
"""
|
||||
ORPO config for ORPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
|
||||
"""
|
||||
KTO config for KTO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
"""
|
||||
CPO config for CPO training
|
||||
"""
|
||||
|
||||
simpo_gamma: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "simpo gamma parameter"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
|
||||
"""
|
||||
PRM config for PRM training
|
||||
"""
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -51,17 +51,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 1_000
|
||||
|
||||
if (
|
||||
hasattr(self.prompt_tokenizer, "filter_rows")
|
||||
and self.prompt_tokenizer.filter_rows
|
||||
):
|
||||
dataset = dataset.filter(
|
||||
self.prompt_tokenizer.filter_rows,
|
||||
num_proc=num_proc,
|
||||
desc="Strategy Filtering Rows",
|
||||
)
|
||||
map_kwargs["batch_size"] = 100
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
@@ -73,24 +63,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
)
|
||||
|
||||
|
||||
def wrap_dataset_for_tokenized_prompt(
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Union[Dataset, IterableDataset],
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = dataset.features.keys()
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
|
||||
@@ -111,17 +111,6 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns a custom class for the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
class: The class for the trainer.
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
@@ -223,17 +212,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
|
||||
module_name, class_name = plugin_name.rsplit(".", 1)
|
||||
|
||||
# import the module
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ModuleNotFoundError as orig_exc:
|
||||
try:
|
||||
if not module_name.startswith("axolotl.integrations."):
|
||||
module = importlib.import_module("axolotl.integrations." + module_name)
|
||||
else:
|
||||
raise orig_exc
|
||||
except ModuleNotFoundError as exc:
|
||||
raise orig_exc from exc
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
# instantiate the class
|
||||
plugin_class = getattr(module, class_name)
|
||||
# create an instance of the class
|
||||
@@ -293,10 +272,8 @@ class PluginManager:
|
||||
ImportError: If the plugin module cannot be imported.
|
||||
"""
|
||||
try:
|
||||
logging.info(f"Attempting to load plugin: {plugin_name}")
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
logging.info(f"Plugin loaded successfully: {plugin_name}")
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
@@ -369,22 +346,6 @@ class PluginManager:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_lora_load(cfg, model)
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
"""
|
||||
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
|
||||
Returns:
|
||||
object: The trainer class, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
trainer_cls = plugin.get_trainer_cls(cfg)
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
"""
|
||||
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Plugin init to add KD support to Axolotl.
|
||||
"""
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
|
||||
class KDPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for KD support in Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.kd.KDArgs"
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
if cfg.kd_trainer:
|
||||
from .trainer import AxolotlKDTrainer
|
||||
|
||||
return AxolotlKDTrainer
|
||||
return None
|
||||
@@ -1,37 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Plugin args for KD support.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KDArgs(BaseModel):
|
||||
"""
|
||||
Input args for knowledge distillation.
|
||||
"""
|
||||
|
||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
||||
kd_ce_alpha: Optional[
|
||||
float
|
||||
] = None # loss coefficient for cross-entropy loss during KD
|
||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||
kd_top_k_before_softmax: Optional[
|
||||
bool
|
||||
] = None # whether to sample top k before softmax during KD
|
||||
@@ -1,201 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Chat template prompt strategy loader with KD support
|
||||
"""
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
"""
|
||||
Handle fields for logprob KD
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=None,
|
||||
train_on_eos=None,
|
||||
logprobs_field="logprobs",
|
||||
gen_temperature=1.0,
|
||||
kd_temperature=1.0,
|
||||
):
|
||||
self.logprobs_field = logprobs_field
|
||||
self.gen_temperature = gen_temperature
|
||||
self.kd_temperature = kd_temperature
|
||||
|
||||
super().__init__(
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
roles_to_train=roles_to_train,
|
||||
train_on_eos=train_on_eos,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
# batching doesn't work well for logprob data
|
||||
return False
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
"""
|
||||
Transform logprobs to target format for KD training
|
||||
"""
|
||||
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
if input_padding_len < 0:
|
||||
# logprobs is longer than target_seq_len,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
# truncate the second dimension of the logprobs to top_k
|
||||
logprobs = [row[:top_k] for row in logprobs]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||
# otherwise, we need to shift in the trainer
|
||||
shift = 0
|
||||
for _ in range(shift, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if sample["labels"][position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
else:
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for _, token_pos_logprobs in enumerate(logprobs):
|
||||
# Initialize collections for logprobs and token_ids
|
||||
position_logprobs = []
|
||||
position_token_ids = []
|
||||
|
||||
# Process each token probability entry
|
||||
for entry in token_pos_logprobs:
|
||||
# Extract logprob value
|
||||
logprob = entry["logprob"]
|
||||
|
||||
# Parse token_id from the "token_id:###" format
|
||||
token_id = int(entry["token"].split(":")[1])
|
||||
|
||||
# Append to our collections
|
||||
position_logprobs.append(logprob)
|
||||
position_token_ids.append(token_id)
|
||||
|
||||
# Convert to a tensor for easier manipulation
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
position_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
else:
|
||||
teacher_probs_t2 = teacher_probs_t1
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
if shift == 1:
|
||||
# since we started at index 1 for causal, we need one more padding token
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class KDStrategyLoader(StrategyLoader):
|
||||
"""
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategyWithKD
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
||||
if logprobs_field := ds_cfg.get("logprobs_field"):
|
||||
strategy_params["logprobs_field"] = logprobs_field
|
||||
if gen_temperature := ds_cfg.get("temperature"):
|
||||
strategy_params["gen_temperature"] = gen_temperature
|
||||
if kd_temperature := cfg.get("kd_temperature"):
|
||||
strategy_params["kd_temperature"] = kd_temperature
|
||||
|
||||
return strategy_params
|
||||
|
||||
|
||||
load = KDStrategyLoader()
|
||||
@@ -1,255 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
DataCollator for axolotl to handle KD fields without using -inf for padding,
|
||||
and with a teacher_mask to identify padded positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Data collator for KD, including handling KD-specific fields.
|
||||
|
||||
This version avoids using -inf and instead uses a large negative value for padding
|
||||
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
model: Optional[Any] = None
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
|
||||
# Pad labels and position_ids first
|
||||
for feature_name, pad_token_id in [
|
||||
("labels", self.label_pad_token_id),
|
||||
("position_ids", self.position_pad_token_id),
|
||||
]:
|
||||
if feature_name in features[0]:
|
||||
feat = [f[feature_name] for f in features]
|
||||
max_len = max(len(x) for x in feat)
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_len = (
|
||||
(max_len + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
) * self.pad_to_multiple_of
|
||||
|
||||
for f in features: # pylint: disable=invalid-name
|
||||
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
|
||||
if isinstance(f[feature_name], list):
|
||||
f[feature_name] = (
|
||||
f[feature_name] + remainder
|
||||
if padding_side == "right"
|
||||
else remainder + f[feature_name]
|
||||
)
|
||||
else:
|
||||
# If they are numpy arrays
|
||||
if padding_side == "right":
|
||||
f[feature_name] = np.concatenate(
|
||||
[f[feature_name], remainder]
|
||||
).astype(np.int64)
|
||||
else:
|
||||
f[feature_name] = np.concatenate(
|
||||
[remainder, f[feature_name]]
|
||||
).astype(np.int64)
|
||||
|
||||
# Handle target_logprobs and target_token_ids manually
|
||||
target_logprobs_list = []
|
||||
target_token_ids_list = []
|
||||
target_mask_list = []
|
||||
has_teacher_data = ("target_logprobs" in features[0]) and (
|
||||
"target_token_ids" in features[0]
|
||||
)
|
||||
|
||||
if has_teacher_data:
|
||||
# Extract and remove from features
|
||||
for f in features: # pylint: disable=invalid-name
|
||||
target_logprobs_list.append(f.pop("target_logprobs"))
|
||||
target_token_ids_list.append(f.pop("target_token_ids"))
|
||||
target_mask_list.append(f.pop("target_mask"))
|
||||
|
||||
# Determine max lengths
|
||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||
|
||||
padded_target_logprobs = []
|
||||
padded_target_token_ids = []
|
||||
padded_teacher_mask_list = []
|
||||
|
||||
for t_logprobs, t_ids, t_mask in zip(
|
||||
target_logprobs_list, target_token_ids_list, target_mask_list
|
||||
):
|
||||
t_logprobs_padded = []
|
||||
t_ids_padded = []
|
||||
t_mask_padded = []
|
||||
|
||||
for lp, ids, mask in zip( # pylint: disable=invalid-name
|
||||
t_logprobs, t_ids, t_mask
|
||||
):
|
||||
lp_len = len(lp)
|
||||
if lp_len < max_k:
|
||||
# Use -1e9 for padding logprobs and 0 for token_ids
|
||||
pad_len = max_k - lp_len
|
||||
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
|
||||
ids = ids + [0] * pad_len
|
||||
mask = mask + [0] * pad_len
|
||||
else:
|
||||
lp = lp[:max_k] # pylint: disable=invalid-name
|
||||
ids = ids[:max_k]
|
||||
mask = mask[:max_k]
|
||||
|
||||
t_logprobs_padded.append(lp)
|
||||
t_ids_padded.append(ids)
|
||||
t_mask_padded.append(mask)
|
||||
|
||||
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
|
||||
if seq_len_diff > 0:
|
||||
# Pad sequences fully if needed
|
||||
t_logprobs_padded.extend(
|
||||
[[-1e9] * max_k for _ in range(seq_len_diff)]
|
||||
)
|
||||
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
|
||||
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
|
||||
|
||||
padded_target_logprobs.append(t_logprobs_padded)
|
||||
padded_target_token_ids.append(t_ids_padded)
|
||||
padded_teacher_mask_list.append(t_mask_padded)
|
||||
|
||||
# Convert to tensors
|
||||
padded_target_logprobs = torch.tensor(
|
||||
padded_target_logprobs, dtype=torch.float
|
||||
)
|
||||
padded_target_token_ids = torch.tensor(
|
||||
padded_target_token_ids, dtype=torch.long
|
||||
)
|
||||
padded_teacher_mask_list = torch.tensor(
|
||||
padded_teacher_mask_list, dtype=torch.int
|
||||
)
|
||||
|
||||
# Pad using tokenizer for regular fields
|
||||
features = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
|
||||
# Add back teacher data if present
|
||||
if has_teacher_data:
|
||||
features["target_logprobs"] = padded_target_logprobs
|
||||
features["target_token_ids"] = padded_target_token_ids
|
||||
features["target_mask"] = padded_teacher_mask_list
|
||||
|
||||
# Prepare decoder_input_ids if the model supports it
|
||||
if (
|
||||
"labels" in features
|
||||
and self.model is not None
|
||||
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
||||
):
|
||||
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
|
||||
labels=features["labels"]
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
"""
|
||||
Collator for multipack (batch of sub-batches) specifically for KD.
|
||||
Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
"""
|
||||
Expects that `features` could be either:
|
||||
- a single list of dicts, OR
|
||||
- a list of lists of dicts (the "sub-batches" to be packed).
|
||||
"""
|
||||
# 1) If we are *not* dealing with multiple sequences per batch element,
|
||||
# just pass straight to parent.
|
||||
if not isinstance(features[0], list):
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
# 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
|
||||
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||
out_features = [{} for _ in features]
|
||||
|
||||
for i, sub_features in enumerate(features):
|
||||
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||
# We'll merge them into out_features[i].
|
||||
#
|
||||
# NOTE: You can customize how you combine fields as needed (e.g. summation
|
||||
# or offset for attention_mask). Below is a straightforward concatenation/extension.
|
||||
|
||||
for field_name in sub_features[0].keys():
|
||||
# Some fields you might want to skip or treat specially:
|
||||
if field_name == "length":
|
||||
continue
|
||||
|
||||
# If it’s a KD field that’s a list-of-lists (e.g. target_logprobs),
|
||||
# you typically just want to flatten them by extending.
|
||||
if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
|
||||
combined = []
|
||||
for feat in sub_features:
|
||||
combined.extend(feat[field_name])
|
||||
out_features[i][field_name] = combined
|
||||
|
||||
elif field_name == "attention_mask":
|
||||
# Here we apply the (j+1) factor to differentiate each sub-sample
|
||||
# within this merged batch item.
|
||||
arrays = []
|
||||
for j, feat in enumerate(sub_features):
|
||||
if field_name in feat:
|
||||
arrays.append((j + 1) * np.array(feat[field_name]))
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
else:
|
||||
# By default, just concatenate them if they are arrays
|
||||
# or extend them if they are lists.
|
||||
# For example, input_ids or labels are often arrays.
|
||||
arrays = []
|
||||
for feat in sub_features:
|
||||
if field_name in feat:
|
||||
arr = np.array(feat[field_name])
|
||||
arrays.append(arr)
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
|
||||
# 3) Now call the parent collator, which will do:
|
||||
# - padding of labels/position_ids
|
||||
# - KD-specific padding for target_logprobs, target_token_ids, etc.
|
||||
# - final conversion to return_tensors
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
@@ -1,58 +0,0 @@
|
||||
### AXOLOTL COMMUNITY LICENSE AGREEMENT
|
||||
|
||||
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
|
||||
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
|
||||
and conditions set forth in this Agreement.
|
||||
|
||||
1. Definitions
|
||||
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
|
||||
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
|
||||
which may be licensed separately by their respective authors and/or licensors.
|
||||
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
|
||||
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
|
||||
permits Plugin Integrations to integrate with the Axolotl service.
|
||||
2. Grant of License
|
||||
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
|
||||
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
|
||||
- Licensee must comply with all the terms and conditions of this Agreement.
|
||||
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
|
||||
portions of the Software.
|
||||
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
|
||||
3. Restrictions
|
||||
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
|
||||
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
|
||||
third parties to fine-tune artificial intelligence models.
|
||||
3.2 Licensee shall not:
|
||||
- Use the Software for any illegal or unauthorized purpose.
|
||||
- Reverse engineer, decompile, or disassemble the Software.
|
||||
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
|
||||
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
|
||||
Software or interfere with any third-party use of the Software.
|
||||
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
|
||||
4. Intellectual Property Rights
|
||||
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
|
||||
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
|
||||
Licensee.
|
||||
5. Disclaimer of Warranty
|
||||
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
|
||||
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
|
||||
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
6. Termination
|
||||
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
|
||||
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
|
||||
copies in its possession.
|
||||
7. Governing Law
|
||||
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
|
||||
without regards to conflicts of laws provisions thereof.
|
||||
8. Entire Agreement
|
||||
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
|
||||
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
|
||||
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
|
||||
Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms
|
||||
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
|
||||
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
|
||||
bound by the terms and conditions of this Agreement.
|
||||
|
||||
This Agreement was last updated on August 23, 2024.
|
||||
@@ -1,235 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# This software may be used and distributed according to
|
||||
# the terms of the Axolotl Community License Agreement (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
|
||||
"""
|
||||
loss for top_k KL divergence
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def zscore_standardize(
|
||||
logits: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
base_temperature: float = 1.0,
|
||||
eps: float = 1e-9,
|
||||
):
|
||||
"""
|
||||
Z-score standardize along the last dimension of `logits`.
|
||||
i.e., for each [B, seq_len] row, across K entries:
|
||||
z = (logits - mean) / std,
|
||||
then scale by 1 / base_temperature if desired.
|
||||
|
||||
mask can be broadcastable or None. If None, we standardize all elements.
|
||||
"""
|
||||
if mask is None:
|
||||
# shape: [B, seq_len, K]
|
||||
# Mean and std over dim=-1
|
||||
mean = logits.mean(dim=-1, keepdim=True)
|
||||
var = logits.var(dim=-1, unbiased=False, keepdim=True)
|
||||
else:
|
||||
# If you have to exclude some tokens, multiply by mask, etc.
|
||||
float_mask = mask.to(logits.dtype)
|
||||
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
||||
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
|
||||
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
|
||||
|
||||
std = torch.sqrt(var.clamp_min(eps))
|
||||
z = (logits - mean) / std
|
||||
|
||||
# Scale by 1 / base_temperature
|
||||
z = z / base_temperature
|
||||
return z
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def loss(
|
||||
student_logits: torch.Tensor,
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
|
||||
Arguments:
|
||||
student_logits (torch.Tensor): The logits of the student model.
|
||||
Shape: [B, student_seq_len, vocab_size]
|
||||
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
target_mask (torch.Tensor): The mask for valid tokens.
|
||||
Shape: [B, teacher_seq_len, top_k]
|
||||
num_items_in_batch (int, optional): The number of items in the batch.
|
||||
kd_temperature (float, optional): The temperature for KD.
|
||||
Default: 1.0
|
||||
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
# Determine the teacher sequence length
|
||||
# target_token_ids shape: [B, teacher_seq_len, K]
|
||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
student_logits_topk = student_logits_topk.float()
|
||||
|
||||
# Apply KD temperature to student’s logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
|
||||
# Convert student top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||
student_logits_topk, dim=-1, keepdim=True
|
||||
) # [B, teacher_seq_len, K]
|
||||
else:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
|
||||
# Convert teacher_mask to boolean for indexing
|
||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||
valid_mask = target_mask.to(torch.bool)
|
||||
|
||||
# Prune tensors to only keep valid tokens
|
||||
student_logprobs_topk = student_logprobs_topk[valid_mask]
|
||||
target_logprobs = target_logprobs[valid_mask]
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs = target_logprobs.exp()
|
||||
|
||||
# Compute forward KL
|
||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# Multiply by T^2 (classical KD scaling)
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items (if provided) or by valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
# Fall back to average over valid tokens
|
||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||
|
||||
return kd_loss
|
||||
|
||||
|
||||
def topk_kd_loss_with_zscore(
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||
kd_temperature: float = 1.0, # classic KD temperature
|
||||
zscore_base_temp: float = 1.0, # from the paper
|
||||
num_items_in_batch: int = -1,
|
||||
):
|
||||
"""
|
||||
A variant of top_k KL divergence with Z-score scaling
|
||||
from "Logit Standardization in Knowledge Distillation".
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||
# 1) Gather the student's top-k logits to match teacher
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, seq_len, vocab]
|
||||
student_topk_logits = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, seq_len, K]
|
||||
|
||||
student_topk_logits = student_topk_logits.float()
|
||||
|
||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||
if kd_temperature != 1.0:
|
||||
student_topk_logits = student_topk_logits / kd_temperature
|
||||
|
||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||
# (They differ by +some_constant from real logits, but in z-score
|
||||
# that constant is subtracted out anyway.)
|
||||
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||
|
||||
# 4) Z-score teacher and student
|
||||
# If target_mask is 2D, expand to 3D for the K dimension
|
||||
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||
|
||||
teacher_z = zscore_standardize(
|
||||
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
student_z = zscore_standardize(
|
||||
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
|
||||
# 5) Convert to log-probs for KL
|
||||
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
|
||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||
|
||||
# 6) Restrict to valid tokens if needed
|
||||
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||
teacher_probs_z = teacher_logprobs_z.exp()
|
||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||
student_logprobs_z = student_logprobs_z[valid_mask]
|
||||
|
||||
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
|
||||
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# 8) If using classical KD scaling by T^2
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
|
||||
# kd_loss = kd_loss * (zscore_base_temp**2)
|
||||
|
||||
# 9) Normalize
|
||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||
|
||||
return kd_loss
|
||||
@@ -1,113 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
if self._signature_columns:
|
||||
if "target_logprobs" not in self._signature_columns:
|
||||
columns_to_add.append("target_logprobs")
|
||||
if "target_token_ids" not in self._signature_columns:
|
||||
columns_to_add.append("target_token_ids")
|
||||
if "target_mask" not in self._signature_columns:
|
||||
columns_to_add.append("target_mask")
|
||||
if columns_to_add:
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
@@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,44 +0,0 @@
|
||||
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
|
||||
|
||||
https://github.com/HazyResearch/lolcats/
|
||||
|
||||
### Usage
|
||||
|
||||
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
|
||||
|
||||
```bash
|
||||
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||
|
||||
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||
# nano setup.py
|
||||
|
||||
# Build the CUDA kernel
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
Step 1:
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.lolcats.LinearizePlugin
|
||||
|
||||
linearize: true
|
||||
```
|
||||
|
||||
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
|
||||
|
||||
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
|
||||
|
||||
```yaml
|
||||
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
|
||||
# with optional config below but this requires patching axolotl
|
||||
# to allow this config to work with lora
|
||||
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
||||
```
|
||||
|
||||
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
|
||||
|
||||
Step 3: Run inference on the finetuned model
|
||||
|
||||
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
|
||||
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
|
||||
|
||||
Low-rank Linear Conversion via Attention Transfer
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
|
||||
DistillAttentionXentMSETrainer,
|
||||
)
|
||||
|
||||
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.lolcats")
|
||||
|
||||
|
||||
class LinearizePlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for lolcats integration with Axolotl.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Register the Linear Llama model with transformers
|
||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||
register_linear_llama,
|
||||
)
|
||||
|
||||
register_linear_llama()
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.lolcats.LinearAttentionArgs"
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
# defualt to XentMSE
|
||||
# TODO: add check to allow MSE_linear
|
||||
if cfg.linearize:
|
||||
return DistillAttentionXentMSETrainer
|
||||
|
||||
return None
|
||||
@@ -1,47 +0,0 @@
|
||||
"""
|
||||
Module for handling linear attention input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FeatureMapKwargs(BaseModel):
|
||||
"""Args for feature map"""
|
||||
|
||||
eps: float
|
||||
mlp: Optional[None] = None
|
||||
fullspace: bool
|
||||
|
||||
|
||||
class LearnedKernelKwargs(BaseModel):
|
||||
"""Args for learned kernel"""
|
||||
|
||||
feature_dim: int
|
||||
skip_connection: bool
|
||||
bias: bool
|
||||
zero_init: bool
|
||||
|
||||
|
||||
class AttentionConfig(BaseModel):
|
||||
"""Args for attention config"""
|
||||
|
||||
attention_type: str
|
||||
feature_map: str
|
||||
feature_map_kwargs: FeatureMapKwargs
|
||||
layer_idx: Optional[None] = None
|
||||
learned_kernel: str
|
||||
learned_kernel_kwargs: LearnedKernelKwargs
|
||||
tie_qk_kernels: bool
|
||||
train_qk: bool
|
||||
|
||||
|
||||
class LinearAttentionArgs(BaseModel):
|
||||
"""
|
||||
Input args for linear attention
|
||||
"""
|
||||
|
||||
attention_config: AttentionConfig
|
||||
|
||||
linearize: Optional[bool] = False
|
||||
@@ -1,90 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Linear LLaMA model configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import LlamaConfig
|
||||
|
||||
|
||||
class LinearLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
|
||||
It is a modified LlamaConfig that includes additional parameters for linear attention.
|
||||
|
||||
Args:
|
||||
attention_config (`dict`):
|
||||
Dictionary containing the configuration for linear attention mechanism.
|
||||
Expected contents:
|
||||
`attention_type` (str):
|
||||
The type of attention to convert to.
|
||||
`feature_map` (`str`):
|
||||
The type of feature map to use for linear attention.
|
||||
`feature_map_kwargs` (`dict`):
|
||||
Additional arguments for the feature map.
|
||||
`learned_kernel` (`str`, *optional*):
|
||||
Type of learned kernel to use, if any.
|
||||
`learned_kernel_kwargs` (`dict`, *optional*):
|
||||
Additional arguments for the learned kernel.
|
||||
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
|
||||
Whether to tie query and key kernels.
|
||||
`rotary_config` (`dict`, *optional*):
|
||||
Configuration for rotary embeddings.
|
||||
`train_attention` (`bool`, *optional*, defaults to False):
|
||||
Whether to train attention to match softmax attention.
|
||||
`remove_base_attn` (`bool`, *optional*, defaults to True):
|
||||
Whether to remove base attention after initialization.
|
||||
`mask_value` (`int`, *optional*, defaults to 0):
|
||||
Value to use for masking.
|
||||
`eps` (`float`, *optional*, defaults to 1e-12):
|
||||
Epsilon value for numerical stability.
|
||||
`fp32_attention` (`bool`, *optional*, defaults to False):
|
||||
Whether to use fp32 precision for attention computation.
|
||||
`track_state_grads` (`bool`, *optional*, defaults to False):
|
||||
Whether to track gradients of attention states.
|
||||
|
||||
**kwargs:
|
||||
Additional arguments inherited from LlamaConfig.
|
||||
"""
|
||||
|
||||
model_type = "linear_llama"
|
||||
|
||||
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set auto_map
|
||||
self.auto_map = {
|
||||
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
|
||||
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
|
||||
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
|
||||
}
|
||||
|
||||
# Set default attention config if none provided
|
||||
self.attention_config = attention_config or {"attention_type": "softmax"}
|
||||
|
||||
@classmethod
|
||||
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
||||
"""
|
||||
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
|
||||
|
||||
Args:
|
||||
llama_config (:class:`~transformers.LlamaConfig`):
|
||||
The LlamaConfig to inherit from.
|
||||
|
||||
attention_config (`dict`):
|
||||
Dictionary containing the configuration for linear attention mechanism.
|
||||
"""
|
||||
|
||||
return cls(attention_config=attention_config, **llama_config.to_dict())
|
||||
@@ -1,30 +0,0 @@
|
||||
# Causal linear attention CUDA kernel
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||
|
||||
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||
# nano setup.py
|
||||
|
||||
# Build the CUDA kernel
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
Reference: https://github.com/idiap/fast-transformers/
|
||||
|
||||
```bib
|
||||
@inproceedings{katharopoulos_et_al_2020,
|
||||
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
|
||||
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
|
||||
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
|
||||
year = {2020}
|
||||
}
|
||||
|
||||
@article{vyas_et_al_2020,
|
||||
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
|
||||
title={Fast Transformers with Clustered Attention},
|
||||
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
@@ -1,6 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
from .causal_attention import causal_dot_product
|
||||
@@ -1,225 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
// Apoorv Vyas <avyas@idiap.ch>
|
||||
//
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
/**
|
||||
* Compute a*b^T and save it into out.
|
||||
*
|
||||
* a \in R^A
|
||||
* b \in R^B
|
||||
*/
|
||||
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
|
||||
for (int i=0; i<A; i++) {
|
||||
float * bi = b;
|
||||
for (int j=0; j<B; j++) {
|
||||
*out += (*a) * (*bi);
|
||||
out++;
|
||||
bi++;
|
||||
}
|
||||
a++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implement a vector matrix product v*m and save it into out.
|
||||
*
|
||||
* v \in R^A
|
||||
* m \in R^{AxB}
|
||||
*/
|
||||
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
|
||||
// TODO: Consider removing the zeroing part and assuming out already
|
||||
// contains 0s
|
||||
for (int i=0; i<B; i++) {
|
||||
out[i] = 0;
|
||||
}
|
||||
|
||||
for (int i=0; i<A; i++) {
|
||||
float *oi = out;
|
||||
for (int j=0; j<B; j++) {
|
||||
*oi += (*v) * (*m);
|
||||
oi++;
|
||||
m++;
|
||||
}
|
||||
v++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implement a vector transposed-matrix product and save it into out.
|
||||
*
|
||||
* v \in R^B
|
||||
* m \in R^{AxB}
|
||||
*/
|
||||
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
|
||||
for (int i=0; i<A; i++) {
|
||||
float *vi = v;
|
||||
float s = 0;
|
||||
for (int j=0; j<B; j++) {
|
||||
s += (*vi) * (*m);
|
||||
vi++;
|
||||
m++;
|
||||
}
|
||||
// TODO: Should we be aggregating? See the comment on vm_dot.
|
||||
*out = s;
|
||||
out++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the causally masked dot products of queries, keys and values.
|
||||
*
|
||||
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
|
||||
* computation is done efficiently by changing the order of the dot products.
|
||||
*/
|
||||
void causal_dot_product(
|
||||
const torch::Tensor queries,
|
||||
const torch::Tensor keys,
|
||||
const torch::Tensor values,
|
||||
torch::Tensor product
|
||||
) {
|
||||
// Extract some shapes
|
||||
int N = queries.size(0);
|
||||
int H = queries.size(1);
|
||||
int L = queries.size(2);
|
||||
int E = queries.size(3);
|
||||
int M = values.size(3);
|
||||
|
||||
// Create accessors for all the arguments
|
||||
auto qa = queries.accessor<float, 4>();
|
||||
auto ka = keys.accessor<float, 4>();
|
||||
auto va = values.accessor<float, 4>();
|
||||
auto pa = product.accessor<float, 4>();
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int n=0; n<N; n++) {
|
||||
for (int h=0; h<H; h++) {
|
||||
auto kv = torch::zeros({E, M}, queries.options());
|
||||
float *kvp = kv.data_ptr<float>();
|
||||
for (int l=0; l<L; l++) {
|
||||
vvt_dot(
|
||||
&ka[n][h][l][0],
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vm_dot(
|
||||
&qa[n][h][l][0],
|
||||
kvp,
|
||||
&pa[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the gradients of queries, keys and values given the gradient of the
|
||||
* causal_dot_product output.
|
||||
*
|
||||
* Make sure that everything is computed in O(N D^2) complexity.
|
||||
*/
|
||||
void causal_dot_backward(
|
||||
const torch::Tensor queries,
|
||||
const torch::Tensor keys,
|
||||
const torch::Tensor values,
|
||||
const torch::Tensor grad_out,
|
||||
torch::Tensor grad_queries,
|
||||
torch::Tensor grad_keys,
|
||||
torch::Tensor grad_values
|
||||
) {
|
||||
// Extract some shapes
|
||||
int N = queries.size(0);
|
||||
int H = queries.size(1);
|
||||
int L = queries.size(2);
|
||||
int E = queries.size(3);
|
||||
int M = values.size(3);
|
||||
|
||||
// Create accessors for all the arguments
|
||||
auto qa = queries.accessor<float, 4>();
|
||||
auto ka = keys.accessor<float, 4>();
|
||||
auto va = values.accessor<float, 4>();
|
||||
auto ga = grad_out.accessor<float, 4>();
|
||||
auto gqa = grad_queries.accessor<float, 4>();
|
||||
auto gka = grad_keys.accessor<float, 4>();
|
||||
auto gva = grad_values.accessor<float, 4>();
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int n=0; n<N; n++) {
|
||||
for (int h=0; h<H; h++) {
|
||||
auto kv = torch::zeros({E, M}, queries.options());
|
||||
float *kvp = kv.data_ptr<float>();
|
||||
|
||||
// Compute the gradient wrt the queries
|
||||
for (int l=0; l<L; l++) {
|
||||
vvt_dot(
|
||||
&ka[n][h][l][0],
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vmt_dot(
|
||||
&ga[n][h][l][0],
|
||||
kvp,
|
||||
&gqa[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
|
||||
// Compute the gradient wrt the keys and values
|
||||
kv.zero_();
|
||||
for (int l=L-1; l>=0; l--) {
|
||||
vvt_dot(
|
||||
&qa[n][h][l][0],
|
||||
&ga[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vmt_dot(
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
&gka[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
vm_dot(
|
||||
&ka[n][h][l][0],
|
||||
kvp,
|
||||
&gva[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"causal_dot_product",
|
||||
&causal_dot_product,
|
||||
"Compute the weighted sum of values but attending only to previous "
|
||||
"values."
|
||||
);
|
||||
m.def(
|
||||
"causal_dot_backward",
|
||||
&causal_dot_backward,
|
||||
"Compute the gradient of queries, keys and values given the gradient "
|
||||
"of causal_dot_product."
|
||||
);
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
|
||||
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
causal_dot_product_cuda = causal_dot_backward_cuda = None
|
||||
|
||||
|
||||
class CausalDotProduct(torch.autograd.Function):
|
||||
"""Compute the weighted sum of values but attending only to previous
|
||||
values."""
|
||||
|
||||
dot = {
|
||||
# "cpu": causal_dot_product_cpu,
|
||||
"cuda": causal_dot_product_cuda
|
||||
}
|
||||
dot_backward = {
|
||||
# "cpu": causal_dot_backward_cpu,
|
||||
"cuda": causal_dot_backward_cuda
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, Q, K, V):
|
||||
# Save the inputs for the gradient computation
|
||||
ctx.save_for_backward(Q, K, V)
|
||||
|
||||
# Create the output tensor
|
||||
device = Q.device
|
||||
N, H, L, _ = Q.shape
|
||||
_, _, _, M = V.shape
|
||||
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
|
||||
|
||||
# Actually perform the dot product
|
||||
CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
|
||||
# breakpoint()
|
||||
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
|
||||
|
||||
return product
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
# Extract the saved tensors
|
||||
Q, K, V = ctx.saved_tensors
|
||||
|
||||
# Allocate memory for the gradients
|
||||
grad_Q = torch.zeros_like(Q)
|
||||
grad_K = torch.zeros_like(K)
|
||||
grad_V = torch.zeros_like(V)
|
||||
|
||||
# Actually compute the gradients
|
||||
CausalDotProduct.dot_backward[Q.device.type](
|
||||
Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
|
||||
)
|
||||
|
||||
return grad_Q, grad_K, grad_V
|
||||
|
||||
|
||||
# Alias the autograd functions to python style snake case naming
|
||||
causal_dot_product = CausalDotProduct.apply
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,65 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
import subprocess # nosec
|
||||
|
||||
import torch
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
def get_last_arch_torch():
|
||||
arch = torch.cuda.get_arch_list()[-1]
|
||||
print(f"Found arch: {arch} from existing torch installation")
|
||||
return arch
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True # nosec
|
||||
)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
||||
|
||||
|
||||
arch = get_last_arch_torch()
|
||||
sm_num = arch[-2:]
|
||||
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"] # for H100
|
||||
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
|
||||
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
|
||||
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
|
||||
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
|
||||
|
||||
setup(
|
||||
name="causal_attention_cuda_cpp",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
"causal_attention_cuda",
|
||||
[
|
||||
# 'causal_attention.cpp',
|
||||
"causal_attention_cuda.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3"],
|
||||
"nvcc": append_nvcc_threads(
|
||||
["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
@@ -1,856 +0,0 @@
|
||||
"""
|
||||
Linear attention classes
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||
try:
|
||||
from csrc import causal_dot_product as fast_causal_dot_product
|
||||
except ImportError:
|
||||
fast_causal_dot_product = None
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
# -------------------
|
||||
# Attention functions
|
||||
# -------------------
|
||||
|
||||
|
||||
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""
|
||||
Causal linear attention dot product
|
||||
- If available, use CUDA kernel from fast-transformers
|
||||
"""
|
||||
if fast_causal_dot_product is None:
|
||||
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
|
||||
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
|
||||
return fast_causal_dot_product(q, k, v)
|
||||
|
||||
|
||||
def linear_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
fp32_attention: bool = False,
|
||||
eps: float = 1e-12,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Compute linear attention with CUDA kernel implementation from fast-transformers
|
||||
- https://github.com/idiap/fast-transformers
|
||||
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
|
||||
v is shape (b, h, l, head_dim)
|
||||
"""
|
||||
dtype = q.dtype
|
||||
# Causal mask already applied
|
||||
y = causal_dot_product(
|
||||
q.contiguous().to(dtype=torch.float32),
|
||||
k.contiguous().to(dtype=torch.float32),
|
||||
v.contiguous().to(dtype=torch.float32),
|
||||
)
|
||||
if fp32_attention:
|
||||
y = (
|
||||
y
|
||||
/ (
|
||||
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
|
||||
)[..., None]
|
||||
).to(dtype=dtype)
|
||||
else:
|
||||
y = y.to(dtype=dtype)
|
||||
k = k.float().cumsum(dim=2).to(dtype=dtype)
|
||||
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
|
||||
return y, None, None
|
||||
|
||||
|
||||
def softmax_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: Optional[torch.Tensor] = None,
|
||||
causal: bool = True,
|
||||
fp32_attention: bool = True,
|
||||
):
|
||||
"""
|
||||
Standard softmax attention; only compute outputs if v is not None
|
||||
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
|
||||
"""
|
||||
y = None
|
||||
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
|
||||
if causal: # Apply causal mask
|
||||
m, n = a.shape[-2:]
|
||||
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||
n - m + 1
|
||||
)
|
||||
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
|
||||
if fp32_attention:
|
||||
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
else:
|
||||
a = torch.softmax(a, dim=-1)
|
||||
if v is not None:
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||
return y, a, None
|
||||
|
||||
|
||||
def quadratic_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: Optional[torch.Tensor] = None,
|
||||
causal: bool = True,
|
||||
fp32_attention: bool = False,
|
||||
eps: float = 1e-12,
|
||||
):
|
||||
"""
|
||||
Compute attention with feature maps by instantiating L x L matrix of attention weights
|
||||
-> Use for attention distillation
|
||||
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
|
||||
"""
|
||||
y = None
|
||||
dtype = q.dtype
|
||||
if fp32_attention:
|
||||
q, k = q.float(), k.float()
|
||||
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
|
||||
if causal: # Apply causal mask
|
||||
m, n = a.shape[-2:]
|
||||
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||
n - m + 1
|
||||
)
|
||||
a = a.masked_fill(causal_mask, 0)
|
||||
# Normalize to compute attention
|
||||
a = a / (a.sum(dim=-1, keepdim=True) + eps)
|
||||
a = a.to(dtype=dtype) if fp32_attention else a
|
||||
if torch.isnan(a).sum() > 0:
|
||||
breakpoint()
|
||||
if v is not None:
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||
return y, a, None
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
|
||||
|
||||
class LolcatsLinearAttention(nn.Module):
|
||||
"""
|
||||
LoLCATs attention implementation initialized from a
|
||||
`LlamaAttention` or `MistralAttention` object (base_attn)
|
||||
|
||||
Most of the arguments are directly tied to argparse args
|
||||
- For now we don't support padding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_attn: nn.Module, # like LlamaAttention
|
||||
feature_map: str,
|
||||
feature_map_kwargs: dict,
|
||||
layer_idx: Optional[int] = None,
|
||||
max_layer_idx: Optional[int] = None,
|
||||
learned_kernel: Optional[str] = None,
|
||||
learned_kernel_kwargs: Optional[dict] = None,
|
||||
tie_qk_kernels: Optional[bool] = False,
|
||||
rotary_config: Optional[dict] = None,
|
||||
train_attention: Optional[bool] = False,
|
||||
remove_base_attn: bool = True,
|
||||
attention_type: Optional[str] = "lolcats_llama",
|
||||
mask_value: int = 0,
|
||||
eps: float = 1e-12,
|
||||
fp32_attention: bool = False,
|
||||
track_state_grads: bool = False,
|
||||
rank: Optional[int] = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.base_config = getattr(base_attn, "config", None)
|
||||
if self.base_config is not None:
|
||||
self.base_config = self.base_config.to_dict()
|
||||
self.attention_type = attention_type
|
||||
self.mask_value = mask_value
|
||||
self.eps = eps
|
||||
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
|
||||
self.max_layer_idx = max_layer_idx
|
||||
self.tie_qk_kernels = tie_qk_kernels
|
||||
self.train_attention = train_attention
|
||||
self.base_inference = False
|
||||
self.fp32_attention = fp32_attention
|
||||
self.track_state_grads = track_state_grads
|
||||
if rank == 0: # multi-gpu
|
||||
if fp32_attention and layer_idx == 0:
|
||||
print(f"-> fp32_attention is {fp32_attention}")
|
||||
if layer_idx == 0 and feature_map_kwargs is not None:
|
||||
for k, v in feature_map_kwargs.items():
|
||||
print(f"-> {k}: {v}")
|
||||
if layer_idx == 0 and learned_kernel_kwargs is not None:
|
||||
for k, v in learned_kernel_kwargs.items():
|
||||
print(f"-> {k}: {v}")
|
||||
|
||||
self.remove_base_attn = remove_base_attn
|
||||
|
||||
self.init_weights_(base_attn, remove_base_attn)
|
||||
self.init_feature_map_(
|
||||
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
|
||||
)
|
||||
|
||||
def init_feature_map_(
|
||||
self,
|
||||
feature_map: str,
|
||||
feature_map_kwargs: dict,
|
||||
learned_kernel: Optional[str] = None,
|
||||
learned_kernel_kwargs: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize MLP-based feature map
|
||||
"""
|
||||
self.fmap_gqa = False # Turn True if specified below
|
||||
if learned_kernel is not None and learned_kernel_kwargs is not None:
|
||||
# Ensure dict
|
||||
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
|
||||
learned_kernel_kwargs["num_heads"] = self.num_heads
|
||||
learned_kernel_kwargs["head_dim"] = self.head_dim
|
||||
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
|
||||
learned_kernel_kwargs["device"] = self.q_proj.weight.device
|
||||
# Create MLP
|
||||
mlp_learned_kernel = init_learned_kernel(
|
||||
learned_kernel, **learned_kernel_kwargs
|
||||
)
|
||||
# Add "activation"; see src.models.feature_map.py
|
||||
self.feature_map_q = init_feature_map(
|
||||
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
|
||||
)
|
||||
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
|
||||
self.feature_map_k = self.feature_map_q
|
||||
else:
|
||||
self.feature_map_k = copy.deepcopy(self.feature_map_q)
|
||||
|
||||
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
|
||||
"""
|
||||
Initialize module layers, weights, positional dependencies, etc.
|
||||
from original softmax attention layer (base_attn)
|
||||
"""
|
||||
# Make other attributes accessible
|
||||
self.attention_dropout = 0 # We don't use dropout
|
||||
self.hidden_size = base_attn.config.hidden_size
|
||||
self.num_heads = base_attn.config.num_attention_heads
|
||||
self.head_dim = base_attn.head_dim
|
||||
self.num_key_value_heads = base_attn.config.num_key_value_heads
|
||||
self.num_key_value_groups = base_attn.num_key_value_groups
|
||||
|
||||
self.q_shape = [self.num_heads, self.head_dim]
|
||||
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
||||
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
||||
|
||||
# Copy original model projection layers
|
||||
self.q_proj = base_attn.q_proj
|
||||
self.k_proj = base_attn.k_proj
|
||||
self.v_proj = base_attn.v_proj
|
||||
self.o_proj = base_attn.o_proj
|
||||
try: # If wanting to use FA2 for ground-truth inference
|
||||
self._flash_attn_uses_top_left_mask = (
|
||||
base_attn._flash_attn_uses_top_left_mask
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if self.remove_base_attn or remove_base_attn:
|
||||
del base_attn # We don't need to keep these around
|
||||
else:
|
||||
self.base_attn = base_attn # For some training runs helpful to just call
|
||||
|
||||
def process_qkv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_value: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Compute queries, keys, and values
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
kv_seq_len = k.shape[-2]
|
||||
|
||||
# Shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
q = q.view(b, l, *self.q_shape).transpose(1, 2)
|
||||
k = k.view(b, l, *self.k_shape).transpose(1, 2)
|
||||
v = v.view(b, l, *self.v_shape).transpose(1, 2)
|
||||
|
||||
if (
|
||||
past_key_value is not None
|
||||
): # and k.shape[2] > q.shape[2]: # e.g., when generating
|
||||
past_key_value.window_size = getattr(
|
||||
self, "decode_window_size", None
|
||||
) # self.decode_window_size
|
||||
if isinstance(
|
||||
past_key_value, Cache
|
||||
): # In Transformers v4.36+ this is a DynamicCache object
|
||||
kv_seq_len += past_key_value.get_usable_length(
|
||||
kv_seq_len, self.layer_idx
|
||||
)
|
||||
else:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
# Apply rotary embeddings
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
k = repeat_kv(k, self.num_key_value_groups)
|
||||
v = repeat_kv(v, self.num_key_value_groups)
|
||||
return q, k, v, kv_seq_len
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
|
||||
- Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_embeddings, past_key_value
|
||||
)
|
||||
|
||||
if self.base_inference:
|
||||
with torch.no_grad():
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
y_true, _, _ = softmax_attention(q, k, v, causal=True)
|
||||
y_true = (
|
||||
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
attn_weights = (None, None)
|
||||
|
||||
elif self.train_attention: # Distilling / learning attentions
|
||||
# Note for now we assume no padding when distilling; attention masks only enforce causality
|
||||
assert (
|
||||
output_attentions is True
|
||||
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
|
||||
with torch.no_grad():
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention (just weights)
|
||||
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
|
||||
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
|
||||
attn_weights = ( # type: ignore
|
||||
(attn_pred, attn_true),
|
||||
(y_pred, _y_true),
|
||||
) # Save both attention weights so we can supervise.
|
||||
|
||||
else: # Finetuning
|
||||
q, k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
# Apply prefill mask
|
||||
if attention_mask is not None and q.shape[2] > 1:
|
||||
if len(attention_mask.shape) == 4:
|
||||
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
|
||||
..., None
|
||||
] # b, 1, k_len, 1
|
||||
else:
|
||||
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
|
||||
k = k.masked_fill(~lin_attn_mask, 0)
|
||||
|
||||
if past_key_value is not None: # Initialize states
|
||||
if len(past_key_value.kv_states) == self.layer_idx:
|
||||
b, h, _, f = k.shape
|
||||
past_key_value.kv_states.append(
|
||||
torch.zeros(
|
||||
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
|
||||
)
|
||||
)
|
||||
past_key_value.k_states.append(
|
||||
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
|
||||
)
|
||||
# Generating
|
||||
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
|
||||
assert use_cache is True
|
||||
kv_state, k_state = past_key_value.update(
|
||||
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
|
||||
)
|
||||
if self.fp32_attention:
|
||||
q = q.float()
|
||||
y_true = (
|
||||
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
|
||||
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
|
||||
..., None
|
||||
]
|
||||
).to(dtype=k.dtype)
|
||||
else:
|
||||
y_true = (
|
||||
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
|
||||
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
|
||||
)
|
||||
else:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
y_true, _, _ = linear_attention(
|
||||
q, k, v, self.fp32_attention, self.eps
|
||||
) # Ordinarily the states are ignored
|
||||
past_key_value.update(
|
||||
k.detach(),
|
||||
v.detach(),
|
||||
self.layer_idx,
|
||||
accumulate_in_fp32=self.fp32_attention,
|
||||
)
|
||||
# doing some unnecessary recomputation here
|
||||
else:
|
||||
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
attn_weights = None
|
||||
|
||||
return y_true, attn_weights
|
||||
|
||||
|
||||
class LinearAttentionState(Cache):
|
||||
"""
|
||||
Handle the KV and K states for linear attention
|
||||
- Adopts HF Transformers `past_key_values` convention
|
||||
- Inherits from `Cache` class
|
||||
- Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""
|
||||
Returns the sequence length of the cached states. A layer index can be optionally passed.
|
||||
"""
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
|
||||
self._seen_tokens_by_layer.append(0)
|
||||
return self._seen_tokens_by_layer[layer_idx]
|
||||
|
||||
def get_max_length(self) -> Optional[int]:
|
||||
"""
|
||||
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_usable_length(
|
||||
self, new_seq_length: int, layer_idx: Optional[int] = 0
|
||||
) -> int:
|
||||
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
||||
# Cache without size limit -> all cache is usable
|
||||
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
||||
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
||||
max_length = self.get_max_length()
|
||||
previous_seq_length = self.get_seq_length(layer_idx)
|
||||
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
||||
return max_length - new_seq_length
|
||||
return previous_seq_length
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = True,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.no_grad():
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
key_states, value_states = key_states.float(), value_states.float()
|
||||
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", key_states, value_states
|
||||
).detach()
|
||||
k_state = key_states.sum(
|
||||
dim=-2, keepdim=True
|
||||
).detach() # b, h, 1, f; note the 1
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
print(
|
||||
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
|
||||
)
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
else:
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def to_legacy_cache(self):
|
||||
"""Hack, but just return self"""
|
||||
return self
|
||||
|
||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||
"""
|
||||
Reorders the cache for beam search, given the selected beam indices.
|
||||
-> Copied from transformers/src/transformers/cache_utils.py
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Reordering cache not implemented for LinearAttentionState"
|
||||
)
|
||||
|
||||
|
||||
# -------------------
|
||||
# feature map functions
|
||||
# -------------------
|
||||
|
||||
|
||||
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
|
||||
"""
|
||||
Initialize feature map final activation for linear attention
|
||||
"""
|
||||
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
||||
|
||||
|
||||
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
||||
"""
|
||||
Initialize feature map final activation for linear attention
|
||||
"""
|
||||
if name == "softmax_dim" and fullspace:
|
||||
return SoftmaxDim(**kwargs)
|
||||
elif name == "softmax_dim" and not fullspace:
|
||||
return SoftmaxDimHalfspace(**kwargs)
|
||||
elif name == "exp_dim" and fullspace:
|
||||
return Exp(**kwargs)
|
||||
elif name == "exp_dim" and not fullspace:
|
||||
return ExpHalfspace(**kwargs)
|
||||
elif name == "pos_elu":
|
||||
return PosELU(**kwargs)
|
||||
elif name == "relu":
|
||||
return ReLU(**kwargs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def init_learned_kernel(name: str, **kwargs):
|
||||
"""
|
||||
Initialize feature map MLP for linear attention
|
||||
"""
|
||||
if name == "untied_head_einsum":
|
||||
return FeatureMapMLP(**kwargs)
|
||||
elif name == "untied_head_adapter":
|
||||
return FeatureMapAdapter(**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FeatureMap(nn.Module):
|
||||
"""
|
||||
Final 'activation' of feature map. Can probably be combined with
|
||||
`FeatureMapMLP` below
|
||||
|
||||
Full feature map is like f(xW + b)
|
||||
-> This is the `f` part
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_name: str,
|
||||
head_dim_idx: int = -1,
|
||||
eps: float = 1e-12,
|
||||
mlp: Optional[nn.Module] = None,
|
||||
fullspace: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim_idx = head_dim_idx
|
||||
self.eps = eps
|
||||
self.mlp = mlp if mlp is not None else nn.Identity()
|
||||
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
||||
|
||||
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
|
||||
"""
|
||||
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||
"""
|
||||
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
||||
|
||||
def q_map(self, *args, **kwargs):
|
||||
"""
|
||||
Use for inference in case q and k feature maps differ
|
||||
"""
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def k_map(self, *args, **kwargs):
|
||||
"""
|
||||
Use for inference in case q and k feature maps differ
|
||||
"""
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Feature map activations
|
||||
# -----------------------
|
||||
class FeatureMapAct(nn.Module):
|
||||
"""
|
||||
Base class for feature map activations
|
||||
"""
|
||||
|
||||
def __init__(self, eps: float = 1e-12):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
"""
|
||||
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class PosELU(FeatureMapAct):
|
||||
"""
|
||||
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return (1 + F.elu(x)).clamp(min=self.eps)
|
||||
|
||||
|
||||
class ReLU(FeatureMapAct):
|
||||
"""
|
||||
ReLU activation as in https://arxiv.org/abs/2103.13076
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return F.relu(x).clamp(min=self.eps)
|
||||
|
||||
|
||||
class SoftmaxDim(FeatureMapAct):
|
||||
"""
|
||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return torch.cat(
|
||||
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
|
||||
).clamp(min=self.eps)
|
||||
|
||||
|
||||
class SoftmaxDimHalfspace(FeatureMapAct):
|
||||
"""
|
||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
||||
|
||||
|
||||
class Exp(FeatureMapAct):
|
||||
"""
|
||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||
x_min = torch.amin(x, dim=-1, keepdim=True)
|
||||
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
|
||||
min=self.eps
|
||||
)
|
||||
|
||||
|
||||
class ExpHalfspace(FeatureMapAct):
|
||||
"""
|
||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||
return torch.exp(x - x_max).clamp(min=self.eps)
|
||||
|
||||
|
||||
# ----------------
|
||||
# Feature map MLPs
|
||||
# ----------------
|
||||
|
||||
|
||||
class FeatureMapMLP(nn.Module):
|
||||
"""
|
||||
Learnable MLP in feature map.
|
||||
|
||||
Full feature map is like f(xW + b)
|
||||
-> This is the `W` and (optional) `b` part
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_dim: int, # input dim
|
||||
feature_dim: int, # output dim
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
skip_connection: bool = False,
|
||||
bias: bool = False,
|
||||
zero_init: bool = False,
|
||||
normal_init: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.feature_dim = feature_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.skip_connection = skip_connection
|
||||
self.bias = bias
|
||||
self.zero_init = zero_init
|
||||
self.normal_init = normal_init
|
||||
self.init_weights_()
|
||||
|
||||
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
||||
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
|
||||
|
||||
if self.normal_init:
|
||||
with torch.no_grad():
|
||||
nn.init.normal_(self.layer)
|
||||
|
||||
if self.skip_connection:
|
||||
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
|
||||
assert self.head_dim == self.feature_dim, assertion_fail
|
||||
|
||||
def init_weights_(self):
|
||||
"""
|
||||
Initialize (W)eights and (b)iases
|
||||
"""
|
||||
self.layer = nn.Parameter(
|
||||
torch.zeros(
|
||||
(self.num_heads, self.head_dim, self.feature_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.layer)
|
||||
|
||||
if self.bias:
|
||||
self.bias = nn.Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_heads, 1, 1), # self.feature_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.bias)
|
||||
else:
|
||||
self.bias = 0.0 # hack
|
||||
|
||||
def zero_init_with_skip_(self):
|
||||
"""
|
||||
Initialize weights to zero matrix if skip connection
|
||||
"""
|
||||
with torch.no_grad():
|
||||
nn.init.zeros_(self.layer)
|
||||
|
||||
def zero_init_(self):
|
||||
"""
|
||||
Initialize weights to identity matrix if no skip connection
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for i in range(self.layer.shape[0]):
|
||||
try:
|
||||
nn.init.eye_(self.layer[i])
|
||||
except RuntimeError:
|
||||
with torch.no_grad():
|
||||
dtype = self.layer[i].dtype
|
||||
weight = torch.eye(
|
||||
*self.layer[i].shape,
|
||||
requires_grad=self.layer[i].requires_grad,
|
||||
device=self.layer[i].device,
|
||||
)
|
||||
self.layer[i] = weight.to(dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||
"""
|
||||
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
|
||||
return x + _x if self.skip_connection else _x
|
||||
|
||||
|
||||
class FeatureMapAdapter(FeatureMapMLP):
|
||||
"""
|
||||
Learnable Feature map with bottleneck adapter
|
||||
as in https://arxiv.org/abs/1902.00751
|
||||
|
||||
We don't use but could be fun to try
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, *args, **kwargs):
|
||||
kwargs["skip_connection"] = True
|
||||
kwargs["bias"] = True
|
||||
kwargs["zero_init"] = True
|
||||
self.hidden_dim = hidden_dim
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_weights_(self):
|
||||
"""
|
||||
Initialize (W)eights and (b)iases
|
||||
"""
|
||||
kwargs = {"dtype": self.dtype, "device": self.device}
|
||||
self.layer0 = nn.Parameter(
|
||||
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
||||
)
|
||||
self.layer1 = nn.Parameter(
|
||||
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.layer0)
|
||||
nn.init.kaiming_uniform_(self.layer1)
|
||||
|
||||
self.bias0 = nn.Parameter(
|
||||
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
|
||||
)
|
||||
self.bias1 = nn.Parameter(
|
||||
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.bias0)
|
||||
nn.init.kaiming_uniform_(self.bias1)
|
||||
|
||||
def zero_init_with_skip_(self):
|
||||
with torch.no_grad():
|
||||
nn.init.zeros_(self.layer0)
|
||||
nn.init.zeros_(self.layer1)
|
||||
nn.init.zeros_(self.bias0)
|
||||
nn.init.zeros_(self.bias1)
|
||||
|
||||
def zero_init_(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||
-> Down-project, apply nonlinearity, up-project; add skip connection
|
||||
"""
|
||||
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
|
||||
_x = F.relu(_x)
|
||||
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
|
||||
return x + _x if self.skip_connection else _x
|
||||
@@ -1,460 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using "standard" sliding windows
|
||||
- Didactically computes outputs with n^2 attention weights for now
|
||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
softmax_attention,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
k_len - q_len
|
||||
)
|
||||
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
k_len - q_len - window_size
|
||||
)
|
||||
window_mask = causal_mask - linear_mask
|
||||
# Return softmax mask (window), linear attention mask
|
||||
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
|
||||
# Determine how we compute attentions
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
self.attention_type = kwargs[
|
||||
"attention_type"
|
||||
] # 'hedgehog_long_llama_window_sw'
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
if self.train_attention:
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
with torch.no_grad():
|
||||
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention outputs
|
||||
# compute attn weights under sliding window
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||
else:
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states must not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Update the decoding KV and K states, and KV cache, during decodeing
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||
# else:
|
||||
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,685 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using "standard" sliding windows
|
||||
- Didactically computes outputs with n^2 attention weights for now
|
||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
causal_dot_product,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
max(k_len - q_len, 0)
|
||||
)
|
||||
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
max(k_len - q_len, 0) - window_size
|
||||
)
|
||||
window_mask = causal_mask - linear_mask
|
||||
# Return softmax mask (window), linear attention mask
|
||||
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Hybrid window attention linear
|
||||
# ------------------------------
|
||||
def under_window_linear_attention(
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_size: int,
|
||||
linear_factor: torch.Tensor,
|
||||
eps: float = 1e-12,
|
||||
):
|
||||
"""Compute hybrid window attention dot product with linear complexity in q_len"""
|
||||
dtype = f_q.dtype
|
||||
w = window_size
|
||||
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||
qkv = linear_factor * causal_dot_product(
|
||||
f_q.contiguous().to(dtype=torch.float32),
|
||||
f_k.contiguous().to(dtype=torch.float32),
|
||||
v.contiguous().to(dtype=torch.float32),
|
||||
).to(dtype=dtype)
|
||||
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
|
||||
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
|
||||
sum_qk[sum_qk == 0] += eps
|
||||
return qkv, sum_qk
|
||||
|
||||
|
||||
def sliding_window_softmax_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_size: int,
|
||||
window_factor: torch.Tensor,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Compute sliding window softmax attention without materializing
|
||||
O(seq_len^2) attention weights
|
||||
"""
|
||||
d = q.shape[-1]
|
||||
# Compute windows for keys
|
||||
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||
|
||||
# Compute windowed_softmax(qk); causal in its construction
|
||||
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
|
||||
a_sm[a_sm == 0] = -torch.finfo(
|
||||
q.dtype
|
||||
).max # heuristic for zeroing out padding above
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
|
||||
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
|
||||
|
||||
|
||||
def hybrid_attention_linear(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: Optional[torch.Tensor] = None,
|
||||
linear_factor: Optional[torch.Tensor] = None,
|
||||
window_size: int = 64,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Alternative hybrid attention combining sliding window and linear attentions
|
||||
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
|
||||
"""
|
||||
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
if window_factor is None:
|
||||
raise ValueError("window_factor must be provided")
|
||||
|
||||
if linear_factor is None:
|
||||
raise ValueError("linear_factor must be provided")
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
with torch.no_grad():
|
||||
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
|
||||
q, k, v, window_size, window_factor, mask_value
|
||||
)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
qkv_ln, sum_qk_ln = under_window_linear_attention(
|
||||
f_q, f_k, v, window_size, linear_factor, eps
|
||||
)
|
||||
|
||||
# 3. Combine
|
||||
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
|
||||
return y, None
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
# Determine how we compute attentions
|
||||
self.linear_attention = hybrid_attention_linear
|
||||
self.attention_type = "lolcats_llama_window_sw"
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
|
||||
if self.train_attention and self.base_inference:
|
||||
with torch.no_grad():
|
||||
_y_true = flash_attention_2(
|
||||
self, # self.base_attn,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)[0]
|
||||
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||
y_true = self.o_proj(y_true)
|
||||
# layer_io = (hidden_states, _y_true) # hack
|
||||
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||
return y_true, layer_io, None
|
||||
|
||||
else:
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.linear_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.linear_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
_y_true = y_true.transpose(1, 2).contiguous()
|
||||
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
|
||||
|
||||
if self.train_attention:
|
||||
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states must not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Update the decoding KV and K states, and KV cache, during decodeing
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||
# else:
|
||||
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
|
||||
|
||||
# -----------------
|
||||
# Flash Attention 2
|
||||
# -----------------
|
||||
|
||||
|
||||
def flash_attention_2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Wrapper for LlamaFlashAttention2
|
||||
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||
"""
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
try: # As in Transformers v4.36
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
except Exception: # As in Transformers v4.39
|
||||
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
LOG.debug(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
if getattr(self, "_flash_attention_forward", False):
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=0, # dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
return attn_output, past_key_value
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
LoLCATs attention combining sliding window and linear attentions
|
||||
- Using standard sliding window arrangement
|
||||
- Training over long sequences with fixed memory with recurrent view
|
||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
from .linear_window_attention_sw import hybrid_attention_quadratic
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
|
||||
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(self, remove_base_attn=True, **kwargs):
|
||||
# keep self.base_attn for Flash Attention inference
|
||||
super().__init__(remove_base_attn=True, **kwargs)
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
@@ -1,466 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using the TK "terracing" arrangement
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
softmax_attention,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
win_len = window_size
|
||||
m = math.ceil(max(q_len, k_len) / window_size)
|
||||
# Creates an n x n mask where n = window_size^2
|
||||
mask = torch.block_diag(
|
||||
*[
|
||||
torch.ones(
|
||||
(win_len, win_len),
|
||||
)
|
||||
]
|
||||
* m
|
||||
)
|
||||
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
|
||||
if mask.shape[0] > q_len:
|
||||
mask = mask[-q_len:]
|
||||
if mask.shape[1] > k_len:
|
||||
mask = mask[:, -k_len:]
|
||||
# Return softmax mask (window), linear attention mask
|
||||
mask = mask[None, None, ...] # b, h, q_len, k_len
|
||||
return (
|
||||
torch.tril(mask).to(device=device, dtype=torch.int),
|
||||
torch.tril(1 - mask).to(device=device, dtype=torch.int),
|
||||
)
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsTKWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
|
||||
# Determine how we compute attentions
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
self.attention_type = kwargs[
|
||||
"attention_type"
|
||||
] # 'hedgehog_long_llama_window_tk'
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
self.window_factor = self.window_factors # legacy naming support
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
if self.train_attention:
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
with torch.no_grad():
|
||||
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention outputs
|
||||
# compute attn weights under sliding window
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||
else:
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionTKWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states should not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("layer_idx should not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Update the decoding KV and K states, and KV cache, during decodeing
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
LoLCATs + ThunderKittens linear attention + sliding window for generation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .linear_attention import LinearAttentionState
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||
|
||||
LOG.debug("Successfully imported ThunderKittens for TK window attention")
|
||||
except ImportError:
|
||||
LOG.debug("Failed to import ThunderKittens for TK window attention")
|
||||
|
||||
|
||||
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
|
||||
def __init__(self, *args, window_size: int = 64, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.train_attention = False
|
||||
self.base_inference = False
|
||||
self.window_size = 64 # hard-coded support for TK kernel
|
||||
self.decode_window_size = 64
|
||||
|
||||
b, h, l, d = 1, 32, 8192, 128
|
||||
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
|
||||
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
|
||||
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Any] = None, # “legacy” cache approach
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
assert (
|
||||
past_key_value is not None
|
||||
), "past_key_value must be provided for generation"
|
||||
assert (
|
||||
self.train_attention is False
|
||||
), "train_attention is not supported for generation"
|
||||
assert (
|
||||
self.base_inference is False
|
||||
), "base_inference is not supported for generation"
|
||||
assert use_cache is True, "use_cache must be True for generation"
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
|
||||
f_q = self.feature_map_q(q)
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k
|
||||
)
|
||||
k_cache, v_cache, kv_state, k_state = _kv
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||
k.shape[-1] ** -0.5
|
||||
)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
|
||||
..., None
|
||||
]
|
||||
)
|
||||
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Process prefill
|
||||
# Use TK-implemented linear + terrace window attention
|
||||
b, h, l, d = q.shape
|
||||
device = q.device
|
||||
# tk.hedgehog arguments
|
||||
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
|
||||
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
|
||||
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
|
||||
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
|
||||
alphas = (
|
||||
1 - betas
|
||||
if self.affine_attention_factors
|
||||
else torch.ones(betas.shape, dtype=torch.float32, device=device)
|
||||
)
|
||||
q_map = self.feature_map_q.mlp.layer
|
||||
k_map = self.feature_map_k.mlp.layer
|
||||
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
|
||||
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
|
||||
# 3. kv_state = torch.einsum(‘bhlf,bhld->bhfd’,
|
||||
# f_k[:, :, :-self.window_size],
|
||||
# v[:, :, :-self.window_size]) # b, h, f, d
|
||||
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
|
||||
|
||||
tk_window_hedgehog_attention(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
self.y_true,
|
||||
self.k_state,
|
||||
self.kv_state,
|
||||
q_map,
|
||||
k_map,
|
||||
alphas,
|
||||
betas,
|
||||
)
|
||||
|
||||
past_key_value.update_with_kv(
|
||||
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
|
||||
)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, None, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a “KV state” and “K state”
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.window_size = window_size
|
||||
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
|
||||
def update_with_kv(
|
||||
self,
|
||||
kv_state: torch.Tensor,
|
||||
k_state: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_idx: int,
|
||||
):
|
||||
"""
|
||||
Update the cache with new KV and K states
|
||||
"""
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += k.shape[2]
|
||||
self._seen_tokens_by_layer.append(k.shape[2])
|
||||
|
||||
# Initialize KV and K states
|
||||
if len(self.decode_k_states) <= layer_idx:
|
||||
self.decode_kv_states.append(kv_state)
|
||||
self.decode_k_states.append(k_state)
|
||||
else: # Update KV and K states
|
||||
self.decode_kv_states[layer_idx] = (
|
||||
self.decode_kv_states[layer_idx] + kv_state
|
||||
)
|
||||
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
|
||||
|
||||
self.k_cache.append(k[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(v[:, :, -self.window_size :, :])
|
||||
|
||||
def update_for_decoding(
|
||||
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
|
||||
):
|
||||
"""
|
||||
Update the cache for decoding
|
||||
"""
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
|
||||
k.dtype
|
||||
)
|
||||
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += k.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,306 +0,0 @@
|
||||
"""
|
||||
LoLCATs attention combining sliding window and linear attentions
|
||||
- Using the TK "terracing" arrangement
|
||||
- Training over long sequences with fixed memory with recurrent view
|
||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from .linear_attention import softmax_attention
|
||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||
|
||||
LOG = logging.getLogger(
|
||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||
)
|
||||
|
||||
|
||||
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(self, remove_base_attn=True, **kwargs):
|
||||
# keep self.base_attn for Flash Attention inference
|
||||
super().__init__(remove_base_attn=True, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
if self.train_attention and self.base_inference:
|
||||
with torch.no_grad():
|
||||
# LOG.debug(hidden_states.shape)
|
||||
_y_true = flash_attention_2(
|
||||
self, # self.base_attn,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
# output_hidden_states=False,
|
||||
use_cache=False,
|
||||
)[0]
|
||||
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||
y_true = self.o_proj(y_true)
|
||||
layer_io = (hidden_states, _y_true) # hack
|
||||
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||
return y_true, layer_io, None
|
||||
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||
k.shape[-1] ** -0.5
|
||||
)
|
||||
# a_sm = torch.softmax(a_sm, dim=-1)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
y_pred = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
|
||||
..., None
|
||||
]
|
||||
)
|
||||
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
if (
|
||||
self.state_grad_enabled
|
||||
and self.layer_idx == 0
|
||||
and position_ids is not None
|
||||
):
|
||||
LOG.debug(
|
||||
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
|
||||
)
|
||||
LOG.debug(
|
||||
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
|
||||
)
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
|
||||
)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
_y_pred = y_pred.transpose(1, 2).contiguous()
|
||||
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
|
||||
|
||||
if self.train_attention:
|
||||
with torch.no_grad():
|
||||
a_true = softmax_attention(q, k, None, causal=True)[1]
|
||||
attn_weights = (_y_pred, (a_pred, a_true))
|
||||
else:
|
||||
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
|
||||
return y_pred, attn_weights, past_key_value
|
||||
|
||||
|
||||
# -----------------
|
||||
# Flash Attention 2
|
||||
# -----------------
|
||||
|
||||
|
||||
def flash_attention_2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Wrapper for LlamaFlashAttention2
|
||||
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
try: # As in Transformers v4.36
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
except Exception: # As in Transformers v4.39
|
||||
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
LOG.debug(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
if getattr(self, "_flash_attention_forward", False):
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=0, # dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
return attn_output, past_key_value
|
||||
@@ -1,361 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
"""Linear LLaMA model implementation."""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
from .configuration_linear_llama import LinearLlamaConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
|
||||
"""
|
||||
Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
# Replace the attention layer with our custom attention
|
||||
self.self_attn = convert_llama_attention(
|
||||
layer=self, attention_config=config.attention_config
|
||||
)
|
||||
|
||||
|
||||
class LinearLlamaModel(LlamaModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LinearLlamaConfig
|
||||
"""
|
||||
|
||||
config_class = LinearLlamaConfig
|
||||
base_model_prefix = "linear_llama"
|
||||
|
||||
def __init__(self, config: LinearLlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LinearLlamaDecoderLayer(config, layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
||||
class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
"""
|
||||
Linear LLaMA model for causal language modeling.
|
||||
"""
|
||||
|
||||
config_class = LinearLlamaConfig
|
||||
base_model_prefix = "linear_llama"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LinearLlamaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls,
|
||||
model: LlamaForCausalLM,
|
||||
config: LinearLlamaConfig,
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
) -> "LinearLlamaForCausalLM":
|
||||
"""
|
||||
Initialize a LinearLlamaForCausalLM from a LlamaModel
|
||||
"""
|
||||
|
||||
if config is None:
|
||||
raise ValueError("Missing config")
|
||||
|
||||
# initialize a new model with config
|
||||
new_model = cls(config=config)
|
||||
|
||||
# remove the default model and lm_head
|
||||
del new_model.model
|
||||
del new_model.lm_head
|
||||
|
||||
# load converted model, lm_head, and vocab_size from llama model
|
||||
new_model.model = convert_attention(
|
||||
model.model,
|
||||
attention_config=config.attention_config,
|
||||
train_attention=train_attention,
|
||||
remove_base_attn=remove_base_attn,
|
||||
)
|
||||
new_model.lm_head = model.lm_head
|
||||
new_model.vocab_size = model.vocab_size
|
||||
|
||||
return new_model
|
||||
|
||||
def toggle_attention(self, train: bool = True):
|
||||
"""
|
||||
Toggle attention to be trainable or not
|
||||
"""
|
||||
|
||||
toggle_attention(self.model, train=train)
|
||||
|
||||
def remove_base_attention(self):
|
||||
"""
|
||||
Remove base attention after distillation
|
||||
"""
|
||||
|
||||
remove_base_attention(self.model)
|
||||
|
||||
|
||||
def convert_attention(
|
||||
model: nn.Module,
|
||||
attention_config: dict,
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
):
|
||||
"""
|
||||
Call to convert all attention layers
|
||||
"""
|
||||
# Get the layers to convert if provided
|
||||
softmax_attns = attention_config.get("softmax_attentions", [])
|
||||
|
||||
# Get the attention to convert to
|
||||
attention_type = attention_config.get("attention_type")
|
||||
|
||||
if attention_type != "softmax":
|
||||
layers = traverse_layers(model)
|
||||
for layer_idx, layer in enumerate(
|
||||
tqdm(layers, desc="Converting attentions...")
|
||||
):
|
||||
if layer_idx not in softmax_attns:
|
||||
layer.self_attn = convert_llama_attention(
|
||||
layer,
|
||||
attention_config,
|
||||
layers,
|
||||
train_attention,
|
||||
remove_base_attn,
|
||||
)
|
||||
layer.self_attn.converted = True
|
||||
else:
|
||||
# Freeze any preserved softmax attention layers
|
||||
for p in layer.parameters():
|
||||
p.requires_grad = False
|
||||
else:
|
||||
LOG.info(
|
||||
f"-> attention_config.attention_type is {attention_type}; not converting attentions"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def toggle_attention(llama_model: nn.Module, train: bool = False):
|
||||
"""
|
||||
Make attentions trainable if train is True
|
||||
-> Set train_attention = False when finetuning
|
||||
"""
|
||||
for layer in traverse_layers(llama_model):
|
||||
layer.self_attn.train_attention = train
|
||||
return llama_model
|
||||
|
||||
|
||||
def remove_base_attention(llama_model: nn.Module):
|
||||
"""
|
||||
Remove teacher attention after distillation (if we keep it)
|
||||
"""
|
||||
for layer in traverse_layers(llama_model):
|
||||
if getattr(layer.self_attn, "base_attn", False):
|
||||
del layer.self_attn.base_attn
|
||||
return llama_model
|
||||
|
||||
|
||||
def traverse_layers(model: nn.Module, verbose: bool = False):
|
||||
"""
|
||||
Return list of model layers
|
||||
"""
|
||||
try:
|
||||
layers = model.model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.model.layers")
|
||||
except AttributeError as e: # if base model
|
||||
if verbose:
|
||||
LOG.info(e)
|
||||
try:
|
||||
layers = model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.layers")
|
||||
except AttributeError as e1: # If we make a PEFT model
|
||||
if verbose:
|
||||
LOG.info(e1)
|
||||
layers = model.base_model.model.model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.base_model.model.model.layers")
|
||||
return layers
|
||||
|
||||
|
||||
def convert_llama_attention(
|
||||
layer: nn.Module,
|
||||
attention_config: dict,
|
||||
layers: Optional[list[nn.Module]] = None, # list of layers
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
):
|
||||
"""
|
||||
Converts a single layer's attention layer as specified by attention_config
|
||||
"""
|
||||
return get_attention(**attention_config)(
|
||||
base_attn=layer.self_attn,
|
||||
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
||||
max_layer_idx=len(layers) - 1 if layers else None,
|
||||
train_attention=train_attention,
|
||||
remove_base_attn=remove_base_attn,
|
||||
)
|
||||
|
||||
|
||||
def get_attention(attention_type: str, **kwargs):
|
||||
"""
|
||||
Get the linear attention class; either purely linear or linear with sliding window
|
||||
-> 'linear' == 'lolcats_llama'
|
||||
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
|
||||
"""
|
||||
kwargs["attention_type"] = attention_type
|
||||
|
||||
if attention_type == "lolcats_llama":
|
||||
from .linear_attention import LolcatsLinearAttention
|
||||
|
||||
return partial(LolcatsLinearAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_tk":
|
||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||
|
||||
return partial(LolcatsTKWindowAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_sw":
|
||||
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
|
||||
|
||||
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_sw_linear":
|
||||
from .linear_window_attention_sw_linear import (
|
||||
LolcatsLinearSlidingWindowAttention,
|
||||
)
|
||||
|
||||
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
||||
|
||||
# Experimental chunked linear attentions below
|
||||
elif attention_type == "lolcats_long_llama_window_tk":
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_long_llama_window_sw":
|
||||
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
||||
|
||||
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
||||
|
||||
# TK generation build (requires Thunderkittens)
|
||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
|
||||
|
||||
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
||||
|
||||
else:
|
||||
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
|
||||
return None
|
||||
|
||||
|
||||
def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
||||
"""
|
||||
Determine how we store past keys and values when generating
|
||||
"""
|
||||
if attention_type is None:
|
||||
return past_key_values
|
||||
|
||||
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
||||
elif "lolcats_llama_window_tk_gen" in attention_type:
|
||||
from .linear_window_attention_tk_gen import (
|
||||
LinearAttentionTKWindowGenerationCache,
|
||||
)
|
||||
|
||||
return LinearAttentionTKWindowGenerationCache()
|
||||
|
||||
elif "llama_window_tk" in attention_type:
|
||||
from .linear_window_attention_tk import LinearAttentionTKWindowCache
|
||||
|
||||
return LinearAttentionTKWindowCache()
|
||||
|
||||
elif "llama_window_sw" in attention_type:
|
||||
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||
|
||||
return LinearAttentionSlidingWindowCache()
|
||||
|
||||
elif "llama_window_sw_linear" in attention_type:
|
||||
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||
|
||||
return LinearAttentionSlidingWindowCache()
|
||||
|
||||
# TK generation build (requires Thunderkittens)
|
||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||
from .linear_window_attention_tk_gen import (
|
||||
LinearAttentionTKWindowGenerationCache,
|
||||
)
|
||||
|
||||
return LinearAttentionTKWindowGenerationCache()
|
||||
|
||||
elif "softmax" in attention_type:
|
||||
return past_key_values
|
||||
|
||||
else:
|
||||
from .linear_attention import LinearAttentionState
|
||||
|
||||
return LinearAttentionState()
|
||||
|
||||
|
||||
def register_linear_llama():
|
||||
"""
|
||||
Register Linear LLaMA model with the Transformers library.
|
||||
"""
|
||||
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
|
||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
||||
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||
|
||||
# registering for auto classes to save files
|
||||
LinearLlamaConfig.register_for_auto_class("AutoConfig")
|
||||
LinearLlamaModel.register_for_auto_class("AutoModel")
|
||||
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||
@@ -1,118 +0,0 @@
|
||||
"""
|
||||
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
|
||||
|
||||
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch import Tensor, nn, tensor
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer class for distilling attentions.
|
||||
- We compute and store the attention outputs and/or weights for each head and layer,
|
||||
for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
|
||||
- We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
mse_factor: float = 1e3,
|
||||
xent_factor: float = 0,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
|
||||
self.criterion_mse = nn.MSELoss(reduction="mean")
|
||||
self.mse_factor = mse_factor
|
||||
self.xent_factor = xent_factor
|
||||
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
|
||||
|
||||
self.model_accepts_loss_kwargs = False # added to combat explosive loss
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Tensor],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> tuple[Tensor, dict]:
|
||||
"""
|
||||
Attention distillation ("attention transfer")
|
||||
- For each layer and head, get attentions and train to
|
||||
minimize some combo of MSE and cross-entropy loss
|
||||
"""
|
||||
# alias inputs to data
|
||||
data = inputs
|
||||
|
||||
device = model.device
|
||||
|
||||
# Filter out labels
|
||||
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
||||
|
||||
# set num_items_in_batch
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
||||
outputs = outputs.get("attentions")
|
||||
|
||||
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
|
||||
# n_layers x (predicted_attns, true_attns)
|
||||
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
|
||||
loss_mse = tensor(0.0, device=device)
|
||||
loss_xent = tensor(0.0, device=device)
|
||||
n_layers = 0 # Number of layers to distill
|
||||
softmax_layers = []
|
||||
for layer_idx, attns in enumerate(outputs):
|
||||
if attns is not None:
|
||||
if len(attns) != 2:
|
||||
attns = attns.cpu()
|
||||
else:
|
||||
if self.xent_factor > 0:
|
||||
# Cross-entropy loss
|
||||
a_pred, a_true = attns[0]
|
||||
a_pred = a_pred.clamp(
|
||||
min=1e-12
|
||||
).log() # nn.CrossEntropy assumes unnormalized logits
|
||||
k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len
|
||||
# Compute mean cross-entropy over all queries
|
||||
a_pred = a_pred.contiguous().view(-1, k_len)
|
||||
a_true = a_true.contiguous().view(-1, k_len)
|
||||
loss_xent += self.criterion_xent(a_pred, a_true)
|
||||
if self.mse_factor > 0:
|
||||
loss_mse += self.criterion_mse(*attns[1])
|
||||
n_layers += 1
|
||||
else:
|
||||
softmax_layers.append(layer_idx)
|
||||
if n_layers > 0:
|
||||
loss_xent = loss_xent / n_layers * self.xent_factor
|
||||
loss_mse = loss_mse / n_layers * self.mse_factor
|
||||
loss = loss_xent + loss_mse
|
||||
|
||||
if "position_ids" in data:
|
||||
outputs = {
|
||||
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
|
||||
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
|
||||
"input_len": data["position_ids"].shape[1],
|
||||
"position_ids": data["position_ids"][0].detach().cpu().numpy(),
|
||||
"mse_factor": self.mse_factor,
|
||||
"xent_factor": self.xent_factor,
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
|
||||
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
|
||||
"mse_factor": self.mse_factor,
|
||||
"xent_factor": self.xent_factor,
|
||||
}
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
@@ -16,21 +16,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
|
||||
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||
load_fn = "load"
|
||||
package = "axolotl.prompt_strategies"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
elif len(strategy.split(".")) > 1:
|
||||
try:
|
||||
importlib.import_module(
|
||||
"." + strategy.split(".")[-1],
|
||||
".".join(strategy.split(".")[:-1]),
|
||||
)
|
||||
package = ".".join(strategy.split(".")[:-1])
|
||||
strategy = strategy.split(".")[-1]
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
mod = importlib.import_module(f".{strategy}", package)
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
|
||||
@@ -10,8 +10,6 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
def load(strategy, cfg, module_base=None, **kwargs):
|
||||
try:
|
||||
if len(strategy.split(".")) == 1:
|
||||
strategy = strategy + ".default"
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", module_base)
|
||||
|
||||
@@ -21,11 +21,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
Bradley-Terry reward model pairwise chat template prompt strategy.
|
||||
"""
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
return False
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
def tokenize_prompt(self, prompt):
|
||||
"""
|
||||
|
||||
:param prompt: the actual row of data from the underlying dataset
|
||||
@@ -43,7 +39,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
@@ -66,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
prompt[self.messages].append(
|
||||
{"role": "assistant", "content": prompt["rejected"]}
|
||||
)
|
||||
rejected_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||
LOG.warning(
|
||||
|
||||
@@ -3,7 +3,6 @@ HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from transformers import ProcessorMixin
|
||||
@@ -194,7 +193,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: ChatTemplatePrompter,
|
||||
prompter,
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
@@ -221,61 +220,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
# Let calling code know we can handle lists of examples
|
||||
return True
|
||||
|
||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||
try:
|
||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||
isinstance(v, list) for v in prompt[self.messages]
|
||||
)
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def tokenize_prompt(self, prompt: dict[str, Any]):
|
||||
"""
|
||||
Public method that can handle either a single prompt or a batch of prompts.
|
||||
"""
|
||||
|
||||
if not self.is_prompt_batched(prompt) or not self.supports_batched:
|
||||
return self._tokenize_single_prompt(prompt)
|
||||
|
||||
res = defaultdict(lambda: [])
|
||||
feature_names = list(prompt.keys())
|
||||
|
||||
# Process each prompt individually
|
||||
for row in zip(*prompt.values()):
|
||||
tokenized_prompt = self._tokenize_single_prompt(
|
||||
dict(zip(feature_names, row))
|
||||
)
|
||||
for key, val in tokenized_prompt.items():
|
||||
for i in range(0, len(val), self.sequence_len):
|
||||
res[key].append(val[i : i + self.sequence_len])
|
||||
|
||||
# If there are no examples left, return an empty dictionary
|
||||
if not res:
|
||||
return {}
|
||||
|
||||
return dict(res)
|
||||
|
||||
def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Old simple legacy behavior that works reliably.
|
||||
if (
|
||||
not self.roles_to_train
|
||||
and not self.train_on_eos
|
||||
and not self.prompter.message_field_training # type: ignore
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
and not self.prompter.message_field_training
|
||||
and not self.prompter.message_field_training_detail
|
||||
):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
images = self.get_images(prompt)
|
||||
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||
prompt_ids = self.prompter.build_prompt(
|
||||
turns[:-1],
|
||||
add_generation_prompt=True,
|
||||
images=images,
|
||||
)
|
||||
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
|
||||
tokenized_res = self.prompter.build_prompt(turns, images=images)
|
||||
tokenized_prompt = {}
|
||||
if isinstance(tokenized_res, list):
|
||||
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
||||
@@ -296,7 +256,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
||||
input_ids = self.prompter.build_prompt(turns)
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
@@ -326,7 +286,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||
if train_detail:
|
||||
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
|
||||
token_offsets = self.prompter.get_offsets_for_train_detail(
|
||||
content, train_detail
|
||||
)
|
||||
LOG.debug(f"Token offsets: {token_offsets}")
|
||||
@@ -499,62 +459,43 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
Load chat template strategy based on configuration.
|
||||
"""
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategy
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
return {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, ds_cfg)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
|
||||
strategy = strategy_cls(
|
||||
ChatTemplatePrompter(**prompter_params),
|
||||
tokenizer=tokenizer,
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
load = StrategyLoader()
|
||||
return strategy
|
||||
|
||||
@@ -3,41 +3,22 @@ DPO strategies for chatml
|
||||
"""
|
||||
|
||||
|
||||
def default(
|
||||
def argilla(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -3,42 +3,22 @@ DPO strategies for llama-3 chat template
|
||||
"""
|
||||
|
||||
|
||||
def default(
|
||||
def argilla(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(sample):
|
||||
# pylint: disable=duplicate-code
|
||||
if "prompt" in sample.keys():
|
||||
prompt_key = "prompt"
|
||||
elif "input" in sample.keys():
|
||||
prompt_key = "input"
|
||||
elif "question" in sample.keys():
|
||||
prompt_key = "question"
|
||||
else:
|
||||
prompt_key = "instruction"
|
||||
|
||||
if "chosen" in sample.keys():
|
||||
chosen_key = "chosen"
|
||||
else:
|
||||
chosen_key = "chosen_response"
|
||||
|
||||
if "rejected" in sample.keys():
|
||||
rejected_key = "rejected"
|
||||
else:
|
||||
rejected_key = "rejected_response"
|
||||
|
||||
if "system" in sample and sample["system"]:
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
@@ -34,8 +34,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
Abstract class for tokenizing strategies
|
||||
"""
|
||||
|
||||
filter_rows: Optional[Callable] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: Prompter,
|
||||
|
||||
@@ -846,12 +846,6 @@ class GCCallback(TrainerCallback):
|
||||
def on_step_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||
if state.global_step % self.gc_steps == 0:
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def on_epoch_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Module for working with config dicts"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -130,18 +129,10 @@ def normalize_config(cfg):
|
||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||
if save_steps < 1.0: # prevent saves on every step
|
||||
cfg.save_steps = save_steps
|
||||
elif save_steps > 1:
|
||||
LOG.warning(
|
||||
f"Invalid value for save_steps ({save_steps}) from saves_per_epoch and/or num_epochs. Saving at training end only."
|
||||
)
|
||||
if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
|
||||
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||
if eval_steps < 1.0: # prevent evals on every step
|
||||
cfg.eval_steps = eval_steps
|
||||
elif eval_steps > 1:
|
||||
LOG.warning(
|
||||
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
|
||||
)
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
|
||||
@@ -163,7 +163,6 @@ class SFTDataset(BaseModel):
|
||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||
input_transform: Optional[str] = None
|
||||
shards: Optional[int] = None
|
||||
preprocess_shards: Optional[int] = None
|
||||
conversation: Optional[str] = None
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: Optional[
|
||||
@@ -186,8 +185,6 @@ class SFTDataset(BaseModel):
|
||||
message_field_content: Optional[str] = None
|
||||
message_field_training: Optional[str] = None
|
||||
message_field_training_detail: Optional[str] = None
|
||||
logprobs_field: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
roles_to_train: Optional[List[str]] = None
|
||||
train_on_eos: Optional[str] = None
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
@@ -864,7 +861,6 @@ class AxolotlInputConfig(
|
||||
|
||||
# INTERNALS - document for now, generally not set externally
|
||||
is_preprocess: Optional[bool] = None
|
||||
preprocess_iterable: Optional[bool] = None
|
||||
|
||||
total_num_tokens: Optional[int] = None
|
||||
total_supervised_tokens: Optional[int] = None
|
||||
|
||||
@@ -3,12 +3,11 @@
|
||||
import functools
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
Sequence,
|
||||
Value,
|
||||
concatenate_datasets,
|
||||
@@ -18,7 +17,7 @@ from datasets import (
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -60,7 +59,7 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_local_main_process()):
|
||||
@@ -71,7 +70,6 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="train",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
_, eval_dataset, _ = load_prepare_datasets(
|
||||
tokenizer,
|
||||
@@ -79,7 +77,6 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
@@ -87,7 +84,6 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
cfg,
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
else:
|
||||
# Load streaming dataset if pretraining_dataset is given
|
||||
@@ -143,7 +139,6 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||
DEFAULT_DATASET_PREPARED_PATH,
|
||||
split="test",
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
@@ -175,7 +170,6 @@ def load_tokenized_prepared_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
preprocess_iterable: Optional[bool] = None,
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
@@ -190,11 +184,10 @@ def load_tokenized_prepared_datasets(
|
||||
+ "@"
|
||||
+ str(cfg.group_by_length)
|
||||
+ "@"
|
||||
+ str(cfg.kd_temperature or 1.0)
|
||||
+ "|".join(
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
|
||||
for d in cfg_datasets
|
||||
]
|
||||
)
|
||||
@@ -269,25 +262,13 @@ def load_tokenized_prepared_datasets(
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield dataset
|
||||
|
||||
streaming_ds = False
|
||||
if preprocess_iterable:
|
||||
streaming_ds = True
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=streaming_ds
|
||||
config_dataset, use_auth_token
|
||||
)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
@@ -344,21 +325,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
if isinstance(dataset, IterableDataset):
|
||||
|
||||
def gen_from_iter_ds(_ds, _=None):
|
||||
yield from _ds
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(gen_from_iter_ds, dataset),
|
||||
features=dataset.features,
|
||||
num_proc=cfg.dataset_processes,
|
||||
split=split,
|
||||
gen_kwargs={"_": list(range(cfg.dataset_processes))},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
else:
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
|
||||
@@ -378,7 +345,6 @@ def load_prepare_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split="train",
|
||||
processor=None,
|
||||
preprocess_iterable: Optional[bool] = False,
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
tokenizer,
|
||||
@@ -386,7 +352,6 @@ def load_prepare_datasets(
|
||||
default_dataset_prepared_path,
|
||||
split=split,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||
@@ -486,7 +451,7 @@ def get_dataset_wrapper(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -499,7 +464,7 @@ def get_dataset_wrapper(
|
||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -522,7 +487,7 @@ def get_dataset_wrapper(
|
||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||
else:
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -535,7 +500,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -549,7 +514,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -563,7 +528,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -577,7 +542,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -591,7 +556,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -605,7 +570,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -619,7 +584,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -633,7 +598,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
|
||||
@@ -29,9 +29,7 @@ def get_ds_type(config_dataset: DictDefault):
|
||||
return ds_type
|
||||
|
||||
|
||||
def load_dataset_w_config(
|
||||
config_dataset, auth_token, streaming=False
|
||||
) -> Union[Dataset, DatasetDict]:
|
||||
def load_dataset_w_config(config_dataset, auth_token):
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
@@ -126,7 +124,7 @@ def load_dataset_w_config(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -159,7 +157,7 @@ def load_dataset_w_config(
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
token=auth_token,
|
||||
revision=config_dataset.revision,
|
||||
@@ -178,7 +176,7 @@ def load_dataset_w_config(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
@@ -189,7 +187,7 @@ def load_dataset_w_config(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
@@ -219,7 +217,7 @@ def load_dataset_w_config(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=streaming,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
|
||||
@@ -10,7 +10,7 @@ from accelerate.utils.environment import get_gpu_info
|
||||
def check_cuda_p2p_ib_support():
|
||||
if not accelerate_check_cuda_p2p_ib_support():
|
||||
return False
|
||||
unsupported_devices = {"RTX 6000 Ada", "L40S"}
|
||||
unsupported_devices = {"RTX 6000 Ada"}
|
||||
try:
|
||||
device_names, device_count = get_gpu_info()
|
||||
if 1 < device_count < 8:
|
||||
|
||||
@@ -4,6 +4,7 @@ Multipack Batch Sampler
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
@@ -116,7 +117,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
drop_last: bool = False,
|
||||
num_count_samples: int = 16,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
@@ -133,9 +133,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
|
||||
self.num_count_samples = num_count_samples
|
||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||
self.len_across_ranks = None
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
@@ -172,9 +169,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.len_across_ranks:
|
||||
# make sure the batches we iterate over is truncated to the same min length across all ranks
|
||||
batches = batches[: self.len_across_ranks]
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
@@ -201,15 +195,42 @@ class MultipackBatchSampler(BatchSampler):
|
||||
def gather_len_batches(self, num):
|
||||
def calc_min_len(estimates: list[(int, float)]):
|
||||
LOG.info(f"gather_len_batches: {repr(estimates)}")
|
||||
return math.floor(min(estimates))
|
||||
return math.floor(0.998 * min(estimates))
|
||||
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
if not self.len_across_ranks:
|
||||
len_batches = min(
|
||||
[self.num_batches() for _ in range(self.num_count_samples)]
|
||||
)
|
||||
len_batches = self.num_batches()
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
return self.len_across_ranks
|
||||
|
||||
def _len_est(self):
|
||||
efficiency = (
|
||||
self.packing_efficiency_estimate
|
||||
if self.packing_efficiency_estimate
|
||||
else self.gather_efficiency()
|
||||
)
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {efficiency} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return max(
|
||||
0,
|
||||
(
|
||||
world_size
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ efficiency
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
)
|
||||
- 1
|
||||
),
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
# Get the input_ids, labels, and attention_mask from the dataset
|
||||
input_ids = example["input_ids"]
|
||||
labels = example["labels"]
|
||||
target_mask = example.pop("target_mask", None)
|
||||
|
||||
# You can compare the input_ids and labels element-wise
|
||||
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
|
||||
@@ -43,13 +42,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
delimiter = "" if text_only else " "
|
||||
LOG.info(delimiter.join(colored_tokens))
|
||||
LOG.info("\n\n\n")
|
||||
target_labels_count = sum(label_id != -100 for label_id in labels)
|
||||
total_len = len(input_ids)
|
||||
LOG.info(f"Total input len: {total_len}")
|
||||
LOG.info(f"Count of labels: {target_labels_count}")
|
||||
if target_mask:
|
||||
target_mask_positions = sum(m[0] for m in target_mask)
|
||||
LOG.info(f"Number of positions in target_mask: {target_mask_positions}")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from datasets import disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@@ -95,41 +95,9 @@ def disable_datasets_caching():
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
"""
|
||||
Handle both single-example and batched data.
|
||||
- single example: sample['input_ids'] is a list[int]
|
||||
- batched data: sample['input_ids'] is a list[list[int]]
|
||||
"""
|
||||
# Return sample unchanged if "input_ids" is not present, or is empty
|
||||
if "input_ids" not in sample or not sample["input_ids"]:
|
||||
return sample
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# If first element is an int, it’s a single example
|
||||
# If first element is a list, it’s a batch
|
||||
if isinstance(input_ids[0], int):
|
||||
# ---- SINGLE EXAMPLE ----
|
||||
seq_len = len(input_ids)
|
||||
# Position IDs for a single example
|
||||
# As a list
|
||||
sample["position_ids"] = list(range(seq_len))
|
||||
sample["length"] = seq_len
|
||||
|
||||
else:
|
||||
# ---- BATCHED EXAMPLES ----
|
||||
# input_ids is a list of lists
|
||||
position_ids_batch = []
|
||||
lengths_batch = []
|
||||
for seq in input_ids:
|
||||
seq_len = len(seq)
|
||||
position_ids_batch.append(list(range(seq_len)))
|
||||
lengths_batch.append(seq_len)
|
||||
|
||||
# Now store them back
|
||||
sample["position_ids"] = position_ids_batch
|
||||
sample["length"] = lengths_batch
|
||||
|
||||
sample_len = len(sample["input_ids"])
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
sample["length"] = sample_len
|
||||
return sample
|
||||
|
||||
|
||||
@@ -204,31 +172,10 @@ def add_length(sample):
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
"""
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Edge case: if input_ids is empty
|
||||
if not input_ids:
|
||||
# Decide if you want to drop or keep empty. Let's drop.
|
||||
return False
|
||||
|
||||
# Check if single example or batched by looking at the first element
|
||||
if isinstance(input_ids[0], int):
|
||||
# Single example (input_ids is a list of int)
|
||||
length = len(input_ids)
|
||||
return min_sequence_len <= length <= sequence_len
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
length = len(seq)
|
||||
results.append(min_sequence_len <= length <= sequence_len)
|
||||
return results
|
||||
return (
|
||||
len(sample["input_ids"]) <= sequence_len
|
||||
and len(sample["input_ids"]) >= min_sequence_len
|
||||
)
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
@@ -238,13 +185,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||
except AttributeError:
|
||||
pass
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||
|
||||
if cfg.model_config_type == "mamba":
|
||||
LOG.info("dropping attention_mask column")
|
||||
@@ -259,105 +203,59 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(train_dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
try:
|
||||
prior_len = len(train_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||
prior_len = len(train_dataset)
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
||||
|
||||
if eval_dataset:
|
||||
try:
|
||||
prior_len = len(eval_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
prior_len = len(eval_dataset)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_long,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
||||
|
||||
# drop samples with where the number of elements with labels not equal to -100 is zero
|
||||
def drop_no_trainable_tokens(sample):
|
||||
"""
|
||||
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
||||
Works for both single-example or batched input.
|
||||
"""
|
||||
labels = sample["labels"]
|
||||
if not labels:
|
||||
return True
|
||||
return np.sum(np.array(sample["labels"]) != -100) > 0
|
||||
|
||||
# Check if single example or batch
|
||||
# If first element is an int, we assume a single example
|
||||
# If it's a list, we assume we're dealing with a batch
|
||||
if isinstance(labels[0], int):
|
||||
# Single example: return a single bool
|
||||
return np.any(labels != -100)
|
||||
|
||||
# Batched: 'labels' is a list of lists
|
||||
# Return a list of booleans, one per sub-list
|
||||
results = [np.any(row_labels != -100) for row_labels in labels]
|
||||
return results
|
||||
|
||||
try:
|
||||
prior_len = len(train_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||
prior_len = len(train_dataset)
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_no_trainable_tokens,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Drop Samples with Zero Trainable Tokens",
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from train dataset"
|
||||
)
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from train dataset"
|
||||
)
|
||||
|
||||
if eval_dataset:
|
||||
try:
|
||||
prior_len = len(eval_dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
prior_len = len(eval_dataset)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_no_trainable_tokens,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Drop Samples with Zero Trainable Tokens",
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
|
||||
)
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
|
||||
)
|
||||
|
||||
if cfg.group_by_length:
|
||||
train_dataset = train_dataset.map(
|
||||
@@ -393,21 +291,19 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing:
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (Sample Packing)",
|
||||
)
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (Sample Packing)",
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
@@ -441,7 +337,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
and not cfg.reward_model
|
||||
):
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.select_columns("input_ids")
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""
|
||||
unit tests for generating sweep configurations
|
||||
"""
|
||||
from axolotl.cli.main import generate_sweep_configs
|
||||
|
||||
|
||||
def test_generate_sweep_configs_no_pairs():
|
||||
base_config = {
|
||||
"learning_rate": 0.1,
|
||||
"micro_batch_size": 1,
|
||||
"sample_packing": True,
|
||||
}
|
||||
|
||||
sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
|
||||
|
||||
generate_sweep_configs(base_config, sweeps_config)
|
||||
|
||||
assert len(generate_sweep_configs(base_config, sweeps_config)) == 6
|
||||
|
||||
cfg_1 = {
|
||||
"learning_rate": 0.1,
|
||||
"micro_batch_size": 2,
|
||||
"weight_decay": 0.0,
|
||||
"sample_packing": True,
|
||||
}
|
||||
|
||||
assert any(
|
||||
cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
|
||||
)
|
||||
|
||||
|
||||
def test_generate_sweep_configs_with_pairs():
|
||||
base_config = {
|
||||
"learning_rate": 0.1,
|
||||
"micro_batch_size": 1,
|
||||
"sample_packing": True,
|
||||
}
|
||||
|
||||
sweeps_config = {
|
||||
"_": [
|
||||
{
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 8,
|
||||
},
|
||||
{
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 4,
|
||||
},
|
||||
{
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
},
|
||||
{
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
},
|
||||
],
|
||||
"weight_decay": [0.0, 0.1],
|
||||
}
|
||||
|
||||
generate_sweep_configs(base_config, sweeps_config)
|
||||
|
||||
assert len(generate_sweep_configs(base_config, sweeps_config)) == 8
|
||||
|
||||
assert all(
|
||||
cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
|
||||
for cfg in generate_sweep_configs(base_config, sweeps_config)
|
||||
)
|
||||
@@ -1,121 +0,0 @@
|
||||
"""
|
||||
e2e tests for kd trainer support in Axolotl
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="kd_min_cfg")
|
||||
def min_cfg(temp_dir):
|
||||
return {
|
||||
"base_model": "osllmai-community/Llama-3.2-1B",
|
||||
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
||||
"plugins": [
|
||||
"axolotl.integrations.kd.KDPlugin",
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
],
|
||||
"liger_rms_norm": True,
|
||||
"liger_glu_activation": True,
|
||||
"torch_compile": True,
|
||||
"chat_template": "llama3",
|
||||
"kd_trainer": True,
|
||||
"kd_ce_alpha": 0.1,
|
||||
"kd_alpha": 0.9,
|
||||
"kd_temperature": 1.0,
|
||||
"dataloader_prefetch_factor": 8,
|
||||
"dataloader_num_workers": 4,
|
||||
"dataloader_pin_memory": True,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
|
||||
"type": "axolotl.integrations.kd.chat_template",
|
||||
"field_messages": "messages_combined",
|
||||
"split": "train",
|
||||
"logprobs_field": "llm_text_generation_vllm_logprobs",
|
||||
"temperature": 1.0,
|
||||
"preprocess_shards": 2,
|
||||
},
|
||||
],
|
||||
"val_set_size": 0.0,
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.00001,
|
||||
"bf16": "auto",
|
||||
"gradient_checkpointing": True,
|
||||
"flash_attention": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|end_of_text|>",
|
||||
"eos_token": "<|eot_id|>",
|
||||
},
|
||||
"max_steps": 5,
|
||||
"output_dir": temp_dir,
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
|
||||
|
||||
class TestKnowledgeDistillation:
|
||||
"""
|
||||
Test case for Knowledge Distillation
|
||||
"""
|
||||
|
||||
# While this will run on torch 2.4.x without torch_compile enabled
|
||||
# the VRAM requirement is higher than what is available in CI
|
||||
@require_torch_2_5_1
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
[True, False],
|
||||
)
|
||||
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": load_in_8bit,
|
||||
"torch_compile": False,
|
||||
"adapter": "lora",
|
||||
"peft_use_dora": True,
|
||||
"lora_target_linear": True,
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.0,
|
||||
}
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
@@ -55,7 +55,6 @@ class LigerIntegrationTestCase:
|
||||
"max_steps": 5,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -101,7 +100,6 @@ class LigerIntegrationTestCase:
|
||||
"max_steps": 5,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
|
||||
)
|
||||
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -82,10 +82,7 @@ def check_tensorboard(
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
else:
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
|
||||
Reference in New Issue
Block a user